您的位置:首頁(yè) > 軟件教程 > 教程 > 強(qiáng)化學(xué)習(xí)筆記之【ACE:Off-PolicyActor-CriticwithCausality-AwareEntropyRegularization】

強(qiáng)化學(xué)習(xí)筆記之【ACE:Off-PolicyActor-CriticwithCausality-AwareEntropyRegularization】

來(lái)源:好特整理 | 時(shí)間:2024-10-18 09:46:01 | 閱讀:130 |  標(biāo)簽: a T CTO AWA Ri rop Pyre S C ICY Causality AR   | 分享到:

2024年ICML文章,ACE:Off-PolicyActor-CriticwithCausality-AwareEntropyRegularization精讀

強(qiáng)化學(xué)習(xí)筆記之【ACE:Off-PolicyActor-CriticwithCausality-AwareEntropyRegularization】


前言:


該論文是清華項(xiàng)目組組內(nèi)博士師兄寫(xiě)的文章,項(xiàng)目主頁(yè)為 ACE (ace-rl.github.io) ,于2024年7月發(fā)表在ICML期刊

因?yàn)樽罱M內(nèi)(其實(shí)只有我)需要從零開(kāi)始做一個(gè)相關(guān)項(xiàng)目,前面的幾篇文章都是鋪墊

本文章為強(qiáng)化學(xué)習(xí)筆記第5篇

本文初編輯于2024.10.5,好像是這個(gè)時(shí)間,忘記了,前后寫(xiě)了兩個(gè)多星期

CSDN主頁(yè): https://blog.csdn.net/rvdgdsva

博客園主頁(yè): https://www.cnblogs.com/hassle

博客園本文鏈接:


論文一覽

這篇強(qiáng)化學(xué)習(xí)論文主要介紹了一個(gè)名為 ACE 的算法,完整名稱為 Off-Policy Actor-Critic with Causality-Aware Entropy Regularization ,它通過(guò)引入因果關(guān)系分析和因果熵正則化來(lái)解決現(xiàn)有模型在不同動(dòng)作維度上的不平等探索問(wèn)題,旨在改進(jìn)強(qiáng)化學(xué)習(xí)【注釋1】中探索效率和樣本效率的問(wèn)題,特別是在高維度連續(xù)控制任務(wù)中的表現(xiàn)。

【注釋1】: 強(qiáng)化學(xué)習(xí)入門(mén)這一篇就夠了


論文摘要

在policy【注釋2】學(xué)習(xí)過(guò)程中,不同原始行為的不同意義被先前的model-free RL 算法所忽視。利用這一見(jiàn)解,我們探索了不同行動(dòng)維度和獎(jiǎng)勵(lì)之間的因果關(guān)系,以評(píng)估訓(xùn)練過(guò)程中各種原始行為的重要性。我們引入了一個(gè)因果關(guān)系感知熵【注釋3】項(xiàng)(causality-aware entropy term),它可以有效地識(shí)別并優(yōu)先考慮具有高潛在影響的行為,以實(shí)現(xiàn)高效的探索。此外,為了防止過(guò)度關(guān)注特定的原始行為,我們分析了梯度休眠現(xiàn)象(gradientdormancyphenomenon),并引入了休眠引導(dǎo)的重置機(jī)制,以進(jìn)一步增強(qiáng)我們方法的有效性。與無(wú)模型RL基線相比,我們提出的算法 ACE :Off-policy A ctor-criticwith C ausality-aware E ntropyregularization。在跨越7個(gè)域的29種不同連續(xù)控制任務(wù)中顯示出實(shí)質(zhì)性的性能優(yōu)勢(shì),這強(qiáng)調(diào)了我們方法的有效性、多功能性和高效的樣本效率。 基準(zhǔn)測(cè)試結(jié)果和視頻可在https://ace-rl.github.io/上獲得。

【注釋2】: 強(qiáng)化學(xué)習(xí)算法中on-policy和off-policy

【注釋3】: 最大熵 RL:從Soft Q-Learning到SAC - 知乎


論文主要貢獻(xiàn):

【1】 因果關(guān)系分析 :通過(guò)引入因果政策-獎(jiǎng)勵(lì)結(jié)構(gòu)模型,評(píng)估不同動(dòng)作維度(即原始行為)對(duì)獎(jiǎng)勵(lì)的影響大小(稱為“因果權(quán)重”)。這些權(quán)重反映了每個(gè)動(dòng)作維度在不同學(xué)習(xí)階段的相對(duì)重要性。

作出上述改進(jìn)的原因是:考慮一個(gè)簡(jiǎn)單的例子,一個(gè)機(jī)械手最初應(yīng)該學(xué)習(xí)放下手臂并抓住物體,然后將注意力轉(zhuǎn)移到學(xué)習(xí)手臂朝著最終目標(biāo)的運(yùn)動(dòng)方向上。因此,在策略學(xué)習(xí)的不同階段強(qiáng)調(diào)對(duì)最重要的原始行為的探索是 至關(guān)重要的。在探索過(guò)程中刻意關(guān)注各種原始行為,可以加速智能體在每個(gè)階段對(duì)基本原始行為的學(xué)習(xí),從而提高掌握完整運(yùn)動(dòng)任務(wù)的效率。

此處可供學(xué)習(xí)的資料:

【2】 因果熵正則化 :在最大熵強(qiáng)化學(xué)習(xí)框架的基礎(chǔ)上(如SAC算法),加入了 因果加權(quán)的熵正則化項(xiàng) 。與傳統(tǒng)熵正則化不同,這一項(xiàng)根據(jù)各個(gè)原始行為的因果權(quán)重動(dòng)態(tài)調(diào)整,強(qiáng)化對(duì)重要行為的探索,減少對(duì)不重要行為的探索。

作出上述改進(jìn)的原因是:論文引入了一個(gè)因果策略-獎(jiǎng)勵(lì)結(jié)構(gòu)模型來(lái)計(jì)算行動(dòng)空間上的因果權(quán)重(causal weights),因果權(quán)重會(huì)引導(dǎo)agent進(jìn)行更有效的探索, 鼓勵(lì)對(duì)因果權(quán)重較大的動(dòng)作維度進(jìn)行探索,表明對(duì)獎(jiǎng)勵(lì)的重要性更大,并減少對(duì)因果權(quán)重較小的行為維度的探 索。一般的最大熵目標(biāo)缺乏對(duì)不同學(xué)習(xí)階段原始行為之間區(qū)別的重要性的認(rèn)識(shí),可能導(dǎo)致低效的探索。為了解決這一限制,論文引入了一個(gè)由因果權(quán)重加權(quán)的策略熵作為因果關(guān)系感知的熵最大化目標(biāo),有效地加強(qiáng)了對(duì)重要原始行為的探索,并導(dǎo)致了更有效的探索。

此處可供學(xué)習(xí)的資料:

【3】 梯度“休眠”現(xiàn)象(Gradient Dormancy) :論文觀察到,模型訓(xùn)練時(shí)有些梯度會(huì)在某些階段不活躍(即“休眠”)。為了防止模型過(guò)度關(guān)注某些原始行為,論文引入了 梯度休眠導(dǎo)向的重置機(jī)制 。該機(jī)制通過(guò)周期性地對(duì)模型進(jìn)行擾動(dòng)(reset),避免模型陷入局部最優(yōu),促進(jìn)更廣泛的探索。

作出上述改進(jìn)的原因是:該機(jī)制通過(guò)一個(gè)由梯度休眠程度決定的因素間歇性地干擾智能體的神經(jīng)網(wǎng)絡(luò)。將因果關(guān)系感知探索與這種新穎的重置機(jī)制相結(jié)合,旨在促進(jìn)更高效、更有效的探索,最終提高智能體的整體性能。

通過(guò)在多個(gè)連續(xù)控制任務(wù)中的實(shí)驗(yàn),ACE 展示出了顯著優(yōu)于主流強(qiáng)化學(xué)習(xí)算法(如SAC、TD3)的表現(xiàn):

  • 29個(gè)不同的連續(xù)控制任務(wù) :包括 Meta-World(12個(gè)任務(wù))、DMControl(5個(gè)任務(wù))、Dexterous Hand(3個(gè)任務(wù))和其他稀疏獎(jiǎng)勵(lì)任務(wù)(6個(gè)任務(wù))。
  • 實(shí)驗(yàn)結(jié)果 表明,ACE 在所有任務(wù)中都達(dá)到了更好的樣本效率和更高的最終性能。例如,在復(fù)雜的稀疏獎(jiǎng)勵(lì)場(chǎng)景中,ACE 憑借其因果權(quán)重引導(dǎo)的探索策略,顯著超越了 SAC 和 TD3 等現(xiàn)有算法。

論文中的對(duì)比實(shí)驗(yàn)圖表顯示了 ACE 在多種任務(wù)下的顯著優(yōu)勢(shì),尤其是在 稀疏獎(jiǎng)勵(lì)和高維度任務(wù) 中,ACE 憑借其探索效率的提升,能更快達(dá)到最優(yōu)策略。


論文代碼框架

在ACE原論文的第21頁(yè),這玩意兒應(yīng)該寫(xiě)在正篇的,害的我看了好久的代碼去排流程

不過(guò)說(shuō)實(shí)話這偽代碼有夠簡(jiǎn)潔的,代碼多少有點(diǎn)糊成一坨了

這是一個(gè)強(qiáng)化學(xué)習(xí)(RL)算法的框架,具體是一個(gè)結(jié)合因果推斷(Causal Discovery)的離策略(Off-policy)Actor-Critic方法。下面是對(duì)每個(gè)模塊及其參數(shù)的說(shuō)明:

1. 初始化模塊

  • Q網(wǎng)絡(luò) ( \(Q_\phi\) ) :用于估計(jì)動(dòng)作價(jià)值,(\phi) 是權(quán)重參數(shù)。
  • 策略網(wǎng)絡(luò) ( $\pi_\theta $) :用于生成動(dòng)作策略,(\theta) 是其權(quán)重。
  • 重放緩沖區(qū) ($ D$ ) :存儲(chǔ)環(huán)境交互的數(shù)據(jù),以便進(jìn)行采樣。
  • 局部緩沖區(qū) ( $D_c $) :存儲(chǔ)因果發(fā)現(xiàn)所需的局部數(shù)據(jù)。
  • 因果權(quán)重矩陣 ($ B_{a \rightarrow r|s} $) :用于捕捉動(dòng)作與獎(jiǎng)勵(lì)之間的因果關(guān)系。
  • 擾動(dòng)因子 ( \(f\) ) :用于對(duì)策略進(jìn)行微小擾動(dòng),增加探索。

2. 因果發(fā)現(xiàn)模塊

  • 每 ( $$I$$ ) 步更新
    • 樣本采樣 :從局部緩沖區(qū) ( \(D_c\) ) 中抽樣 ( \(N_c\) ) 條轉(zhuǎn)移。
    • 更新因果權(quán)重矩陣 :調(diào)整 ($ B_{a \rightarrow r|s}$ ),用于反映當(dāng)前策略和獎(jiǎng)勵(lì)之間的因果關(guān)系。

3. 策略優(yōu)化模塊

  • 每個(gè)梯度步驟
    • 樣本采樣 :從重放緩沖區(qū) ( \(D\) ) 中抽樣 ($ N$ ) 條轉(zhuǎn)移。
    • 計(jì)算因果意識(shí)熵 ( \(H_c(\pi(\cdot|s))\) ) :衡量在給定狀態(tài)下策略的隨機(jī)性和確定性,用于修改策略。
    • 目標(biāo) Q 值計(jì)算 :更新目標(biāo) Q 值,用于訓(xùn)練 Q 網(wǎng)絡(luò)。
    • 更新 Q 網(wǎng)絡(luò) :減少預(yù)測(cè)的 Q 值與目標(biāo) Q 值之間的誤差。
    • 更新策略網(wǎng)絡(luò) :最大化當(dāng)前狀態(tài)下的 Q 值,以提高收益。

4. 重置機(jī)制模塊

  • 每個(gè)重置間隔
    • 計(jì)算梯度主導(dǎo)度 ( $\beta_\gamma $) :用來(lái)量化策略更新的影響程度。
    • 初始化隨機(jī)網(wǎng)絡(luò) :為新的策略更新準(zhǔn)備初始權(quán)重 ( $\phi_i $)。
    • 軟重置策略和 Q 網(wǎng)絡(luò) :根據(jù)因果權(quán)重進(jìn)行平滑更新,幫助實(shí)現(xiàn)更穩(wěn)定的優(yōu)化。
    • 重置策略和 Q 優(yōu)化器 :在重置時(shí)清空狀態(tài),以便進(jìn)行新的學(xué)習(xí)過(guò)程。

論文源代碼主干

源代碼上千行呢,這里只是貼上main_casual里面的部分代碼,并且刪掉了很大一部分代碼以便理清程序脈絡(luò)

def train_loop(config, msg = "default"):
    # Agent
    agent = ACE_agent(env.observation_space.shape[0], env.action_space, config)

    memory = ReplayMemory(config.replay_size, config.seed)
    local_buffer = ReplayMemory(config.causal_sample_size, config.seed)

    for i_episode in itertools.count(1):
        done = False

        state = env.reset()
        while not done:
            if config.start_steps > total_numsteps:
                action = env.action_space.sample()  # Sample random action
            else:
                action = agent.select_action(state)  # Sample action from policy

            if len(memory) > config.batch_size:
                for i in range(config.updates_per_step):
                    #* Update parameters of causal weight
                    if (total_numsteps % config.causal_sample_interval == 0) and (len(local_buffer)>=config.causal_sample_size):
                        causal_weight, causal_computing_time = get_sa2r_weight(env, local_buffer, agent, sample_size=config.causal_sample_size, causal_method='DirectLiNGAM')
                        print("Current Causal Weight is: ",causal_weight)
                        
                    dormant_metrics = {}
                    # Update parameters of all the networks
                    critic_1_loss, critic_2_loss, policy_loss, ent_loss, alpha, q_sac, dormant_metrics = agent.update_parameters(memory, causal_weight,config.batch_size, updates)

                    updates += 1
            next_state, reward, done, info = env.step(action) # Step
            total_numsteps += 1
            episode_steps += 1
            episode_reward += reward

            #* Ignore the "done" signal if it comes from hitting the time horizon.
            if '_max_episode_steps' in dir(env):  
                mask = 1 if episode_steps == env._max_episode_steps else float(not done)
            elif 'max_path_length' in dir(env):
                mask = 1 if episode_steps == env.max_path_length else float(not done)
            else: 
                mask = 1 if episode_steps == 1000 else float(not done)

            memory.push(state, action, reward, next_state, mask) # Append transition to memory
            local_buffer.push(state, action, reward, next_state, mask) # Append transition to local_buffer
            state = next_state

        if total_numsteps > config.num_steps:
            break

        # test agent
        if i_episode % config.eval_interval == 0 and config.eval is True:
            eval_reward_list = []
            for _  in range(config.eval_episodes):
                state = env.reset()
                episode_reward = []
                done = False
                while not done:
                    action = agent.select_action(state, evaluate=True)
                    next_state, reward, done, info = env.step(action)
                    state = next_state
                    episode_reward.append(reward)
                eval_reward_list.append(sum(episode_reward))

            avg_reward = np.average(eval_reward_list)
          
    env.close() 

代碼流程解釋

  1. 初始化 :

    • 通過(guò)配置文件 config 設(shè)置環(huán)境和隨機(jī)種子。
    • 使用 ACE_agent 初始化強(qiáng)化學(xué)習(xí)智能體,該智能體會(huì)在后續(xù)過(guò)程中學(xué)習(xí)如何在環(huán)境中行動(dòng)。
    • 創(chuàng)建存儲(chǔ)結(jié)果和檢查點(diǎn)的目錄,確保訓(xùn)練過(guò)程中的配置和因果權(quán)重會(huì)被記錄下來(lái)。
    • 初始化了兩個(gè)重放緩沖區(qū): memory 用于存儲(chǔ)所有的歷史數(shù)據(jù), local_buffer 則用于因果權(quán)重的更新。
  2. 主訓(xùn)練循環(huán) :

    • 采樣動(dòng)作 :如果總步數(shù)較小,則從環(huán)境中隨機(jī)采樣動(dòng)作,否則從策略中選擇動(dòng)作。通過(guò)這種方式,確保早期探索和后期利用。

    • 更新因果權(quán)重 :在特定間隔內(nèi),從局部緩沖區(qū)中采樣數(shù)據(jù),通過(guò) get_sa2r_weight 函數(shù)使用DirectLiNGAM算法計(jì)算從動(dòng)作到獎(jiǎng)勵(lì)的因果權(quán)重。這個(gè)權(quán)重會(huì)作為額外信息,幫助智能體優(yōu)化策略。

    • 更新網(wǎng)絡(luò)參數(shù) :當(dāng) memory 中的數(shù)據(jù)足夠多時(shí),開(kāi)始通過(guò)采樣更新Q網(wǎng)絡(luò)和策略網(wǎng)絡(luò),使用計(jì)算出的因果權(quán)重來(lái)修正損失函數(shù)。

    • 記錄與保存模型 :每隔一定的步數(shù),算法會(huì)測(cè)試當(dāng)前策略的性能,記錄并比較獎(jiǎng)勵(lì)是否超過(guò)歷史最佳值,如果是,則保存模型的檢查點(diǎn)。

    • 使用 wandb 記錄訓(xùn)練過(guò)程中的指標(biāo),例如損失函數(shù)、獎(jiǎng)勵(lì)和因果權(quán)重的計(jì)算時(shí)間,這些信息可以幫助調(diào)試和分析訓(xùn)練過(guò)程。


論文模塊代碼及實(shí)現(xiàn)

因果發(fā)現(xiàn)模塊

因果發(fā)現(xiàn)模塊 主要通過(guò) get_sa2r_weight 函數(shù)實(shí)現(xiàn),并且與 DirectLiNGAM 模型結(jié)合,負(fù)責(zé)計(jì)算因果權(quán)重。具體代碼在訓(xùn)練循環(huán)中如下:

causal_weight, causal_computing_time = get_sa2r_weight(env, local_buffer, agent, sample_size=config.causal_sample_size, causal_method='DirectLiNGAM')

在這個(gè)代碼段, get_sa2r_weight 函數(shù)會(huì)基于當(dāng)前環(huán)境、樣本數(shù)據(jù)( local_buffer )和因果模型(這里使用的是 DirectLiNGAM ),計(jì)算與行動(dòng)相關(guān)的因果權(quán)重( causal_weight )。這些權(quán)重會(huì)影響后續(xù)的策略優(yōu)化和參數(shù)更新。關(guān)鍵邏輯包括:

  1. 采樣間隔 :因果發(fā)現(xiàn)是在 total_numsteps % config.causal_sample_interval == 0 時(shí)觸發(fā),確保只在指定的步數(shù)間隔內(nèi)計(jì)算因果權(quán)重,避免每一步都進(jìn)行因果計(jì)算,減輕計(jì)算負(fù)擔(dān)。
  2. 局部緩沖區(qū) local_buffer 中存儲(chǔ)了足夠的樣本( config.causal_sample_size ),這些樣本用于因果關(guān)系的發(fā)現(xiàn)。
  3. 因果方法 DirectLiNGAM 是選擇的因果模型,用于從狀態(tài)、行動(dòng)和獎(jiǎng)勵(lì)之間推導(dǎo)出因果關(guān)系。

因果權(quán)重計(jì)算完成后,程序會(huì)將這些權(quán)重應(yīng)用到策略優(yōu)化中,并且記錄權(quán)重及計(jì)算時(shí)間等信息。

def get_sa2r_weight(env, memory, agent, sample_size=5000, causal_method='DirectLiNGAM'):
    ······
    return weight, model._running_time

這個(gè)代碼的核心是利用DirectLiNGAM模型計(jì)算給定狀態(tài)、動(dòng)作和獎(jiǎng)勵(lì)之間的因果權(quán)重。接下來(lái),用LaTeX公式詳細(xì)表述計(jì)算因果權(quán)重的過(guò)程:

  1. 數(shù)據(jù)預(yù)處理
    將從 memory 中采樣的 states (狀態(tài))、 actions (動(dòng)作)和 rewards (獎(jiǎng)勵(lì))進(jìn)行拼接,構(gòu)建輸入數(shù)據(jù)矩陣 \(X_{\text{ori}}\)

    其中, \(S\) 代表狀態(tài), \(A\) 代表動(dòng)作, \(R\) 代表獎(jiǎng)勵(lì)。接著,構(gòu)建數(shù)據(jù)框 \(X\) 來(lái)進(jìn)行因果分析。

  2. 因果模型擬合

    X_ori 轉(zhuǎn)換為 X 是為了利用 pandas 數(shù)據(jù)框的便利性和靈活性

    使用 DirectLiNGAM 模型對(duì)矩陣 \(X\) 進(jìn)行擬合,得到因果關(guān)系的鄰接矩陣 \(A_{\text{model}}\)

    該鄰接矩陣表示狀態(tài)、動(dòng)作、獎(jiǎng)勵(lì)之間的因果結(jié)構(gòu),特別是從動(dòng)作到獎(jiǎng)勵(lì)的影響關(guān)系。

  3. 提取動(dòng)作對(duì)獎(jiǎng)勵(lì)的因果權(quán)重
    通過(guò)鄰接矩陣提取動(dòng)作對(duì)獎(jiǎng)勵(lì)的因果權(quán)重 \(w_{\text{r}}\) ,該權(quán)重從鄰接矩陣的最后一行中選擇與動(dòng)作對(duì)應(yīng)的元素:

    其中, \(d_s\) 是狀態(tài)的維度, \(d_a\) 是動(dòng)作的維度。

  4. 因果權(quán)重的歸一化
    對(duì)因果權(quán)重 \(w_{\text{r}}\) 進(jìn)行Softmax歸一化,確保它們的總和為1:

  5. 調(diào)整權(quán)重的尺度
    最后,因果權(quán)重根據(jù)動(dòng)作的數(shù)量進(jìn)行縮放:

最終輸出的權(quán)重 \(w\) 表示每個(gè)動(dòng)作對(duì)獎(jiǎng)勵(lì)的因果影響,經(jīng)過(guò)歸一化和縮放處理,可以用于進(jìn)一步的策略調(diào)整或分析。

策略優(yōu)化模塊

以下是對(duì)函數(shù)工作原理的逐步解釋:

策略優(yōu)化模塊 主要由 agent.update_parameters 函數(shù)實(shí)現(xiàn)。 agent.update_parameters 這個(gè)函數(shù)的主要目的是在強(qiáng)化學(xué)習(xí)中更新策略 ( policy ) 和價(jià)值網(wǎng)絡(luò)(critic)的參數(shù),以提升智能體的性能。這個(gè)函數(shù)實(shí)現(xiàn)了一個(gè)基于軟演員評(píng)論家(SAC, Soft Actor-Critic)的更新機(jī)制,并且加入了因果權(quán)重與"休眠"神經(jīng)元(dormant neurons)的處理,以提高模型的魯棒性和穩(wěn)定性。

critic_1_loss, critic_2_loss, policy_loss, ent_loss, alpha, q_sac, dormant_metrics = agent.update_parameters(memory, causal_weight, config.batch_size, updates)

通過(guò) agent.update_parameters 函數(shù),程序會(huì)更新以下幾個(gè)部分:

  1. Critic網(wǎng)絡(luò)(價(jià)值網(wǎng)絡(luò)) critic_1_loss critic_2_loss 分別是兩個(gè) Critic 網(wǎng)絡(luò)的損失,用于評(píng)估當(dāng)前策略的價(jià)值。
  2. Policy網(wǎng)絡(luò)(策略網(wǎng)絡(luò)) policy_loss 表示策略網(wǎng)絡(luò)的損失,用于優(yōu)化 agent 的行動(dòng)選擇。
  3. Entropy損失 ent_loss 用來(lái)調(diào)節(jié)策略的隨機(jī)性,幫助 agent 在探索和利用之間找到平衡。
  4. Alpha :表示自適應(yīng)的熵系數(shù),用于調(diào)整探索與利用之間的權(quán)衡。

這些參數(shù)的更新在每次訓(xùn)練循環(huán)中被調(diào)用,并使用 wandb.log 記錄損失和其他相關(guān)的訓(xùn)練數(shù)據(jù)。

update_parameters ACE_agent 類中的一個(gè)關(guān)鍵函數(shù),用于根據(jù)經(jīng)驗(yàn)回放緩沖區(qū)中的樣本數(shù)據(jù)來(lái)更新模型的參數(shù)。下面是對(duì)其工作原理的詳細(xì)解釋:

1. 采樣經(jīng)驗(yàn)數(shù)據(jù)

首先,函數(shù)從 memory 中采樣一批樣本( state_batch 、 action_batch reward_batch 、 next_state_batch 、 mask_batch ),其中包括狀態(tài)、動(dòng)作、獎(jiǎng)勵(lì)、下一個(gè)狀態(tài)以及掩碼,用于表示是否為終止?fàn)顟B(tài)。

state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(batch_size=batch_size)
  • state_batch :當(dāng)前的狀態(tài)。
  • action_batch :在當(dāng)前狀態(tài)下執(zhí)行的動(dòng)作。
  • reward_batch :執(zhí)行該動(dòng)作后獲得的獎(jiǎng)勵(lì)。
  • next_state_batch :執(zhí)行動(dòng)作后到達(dá)的下一個(gè)狀態(tài)。
  • mask_batch :掩碼,用于表示是否為終止?fàn)顟B(tài)(1 表示非終止,0 表示終止)。

2. 計(jì)算目標(biāo) Q 值

利用當(dāng)前策略(policy)網(wǎng)絡(luò),采樣下一個(gè)狀態(tài)的動(dòng)作 next_state_action 和其對(duì)應(yīng)的概率分布對(duì)數(shù) next_state_log_pi 。然后利用目標(biāo) Q 網(wǎng)絡(luò) critic_target 估計(jì)下一時(shí)刻的最小 Q 值,并結(jié)合獎(jiǎng)勵(lì)和折扣因子 \(\gamma\) 計(jì)算下一個(gè) Q 值:

with torch.no_grad():
    next_state_action, next_state_log_pi, _ = self.policy.sample(next_state_batch, causal_weight)
    qf1_next_target, qf2_next_target = self.critic_target(next_state_batch, next_state_action)
    min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - self.alpha * next_state_log_pi
    next_q_value = reward_batch + mask_batch * self.gamma * (min_qf_next_target)
  • 通過(guò)策略網(wǎng)絡(luò) self.policy 為下一個(gè)狀態(tài) next_state_batch 采樣動(dòng)作 next_state_action 和相應(yīng)的策略熵 next_state_log_pi

  • 使用目標(biāo) Q 網(wǎng)絡(luò)計(jì)算 qf1_next_target qf2_next_target ,并取兩者的最小值來(lái)減少估計(jì)偏差。

  • 最終使用貝爾曼方程計(jì)算 next_q_value ,即當(dāng)前的獎(jiǎng)勵(lì)加上折扣因子 \(\gamma\) 乘以下一個(gè)狀態(tài)的 Q 值。

  • 這里, \(\alpha\) 是熵項(xiàng)的權(quán)重,用于平衡探索和利用的權(quán)衡,而 mask_batch 是為了處理終止?fàn)顟B(tài)的情況。

    使用無(wú)偏估計(jì)來(lái)計(jì)算目標(biāo) Q 值。通過(guò)目標(biāo)網(wǎng)絡(luò) ( critic_target ) 計(jì)算出下一個(gè)狀態(tài)和動(dòng)作的 Q 值,并使用獎(jiǎng)勵(lì)和掩碼更新當(dāng)前 Q 值

3. 更新 Q 網(wǎng)絡(luò)

接著,使用當(dāng)前 Q 網(wǎng)絡(luò) critic 估計(jì)當(dāng)前狀態(tài)和動(dòng)作下的 Q 值 \(Q_1\) \(Q_2\) ,并計(jì)算它們與目標(biāo) Q 值的均方誤差損失:

最終 Q 網(wǎng)絡(luò)的總損失是兩個(gè) Q 網(wǎng)絡(luò)損失之和:

然后,通過(guò)反向傳播 qf_loss 來(lái)更新 Q 網(wǎng)絡(luò)的參數(shù)。

qf1, qf2 = self.critic(state_batch, action_batch)
qf1_loss = F.mse_loss(qf1, next_q_value)
qf2_loss = F.mse_loss(qf2, next_q_value)
qf_loss = qf1_loss + qf2_loss

self.critic_optim.zero_grad()
qf_loss.backward()
self.critic_optim.step()
  • qf1 qf2 是兩個(gè) Q 網(wǎng)絡(luò)的輸出,用于減少正向估計(jì)偏差。
  • 損失函數(shù)是 Q 值的均方誤差(MSE), qf1_loss qf2_loss 分別計(jì)算兩個(gè) Q 網(wǎng)絡(luò)的誤差,最后將兩者相加為總的 Q 損失 qf_loss
  • 通過(guò) self.critic_optim 優(yōu)化器對(duì)損失進(jìn)行反向傳播和參數(shù)更新。

4. 策略網(wǎng)絡(luò)更新

每隔若干步(通過(guò) target_update_interval 控制),開(kāi)始更新策略網(wǎng)絡(luò) policy 。首先,重新采樣當(dāng)前狀態(tài)下的策略 \(\pi(a|s)\) ,并計(jì)算 Q 值和熵權(quán)重下的策略損失:

這個(gè)損失通過(guò)反向傳播更新策略網(wǎng)絡(luò)。

if updates % self.target_update_interval == 0:
    pi, log_pi, _ = self.policy.sample(state_batch, causal_weight)
    qf1_pi, qf2_pi = self.critic(state_batch, pi)
    min_qf_pi = torch.min(qf1_pi, qf2_pi)
    policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean()

    self.policy_optim.zero_grad()
    policy_loss.backward()
    self.policy_optim.step()
  • 通過(guò)策略網(wǎng)絡(luò)對(duì)當(dāng)前狀態(tài) state_batch 進(jìn)行采樣,得到動(dòng)作 pi 及其對(duì)應(yīng)的策略熵 log_pi 。
  • 計(jì)算策略損失 policy_loss ,即 \(\alpha\) 倍的策略熵減去最小的 Q 值。
  • 通過(guò) self.policy_optim 優(yōu)化器對(duì)策略損失進(jìn)行反向傳播和參數(shù)更新。

5. 自適應(yīng)熵調(diào)節(jié)

如果開(kāi)啟了自動(dòng)熵項(xiàng)調(diào)整( automatic_entropy_tuning ),則會(huì)進(jìn)一步更新熵項(xiàng) \(\alpha\) 的損失:

并通過(guò)梯度下降更新 \(\alpha\) 。

如果 automatic_entropy_tuning 為真,則會(huì)更新熵項(xiàng):

if self.automatic_entropy_tuning:
    alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()
    self.alpha_optim.zero_grad()
    alpha_loss.backward()
    self.alpha_optim.step()
    self.alpha = self.log_alpha.exp()
    alpha_tlogs = self.alpha.clone()
else:
    alpha_loss = torch.tensor(0.).to(self.device)
    alpha_tlogs = torch.tensor(self.alpha) # For TensorboardX logs
  • 通過(guò)計(jì)算 alpha_loss 更新 self.alpha ,調(diào)整策略的探索-利用平衡。

6. 返回值

  • qf1_loss , qf2_loss : 兩個(gè) Q 網(wǎng)絡(luò)的損失
  • policy_loss : 策略網(wǎng)絡(luò)的損失
  • alpha_loss : 熵權(quán)重的損失
  • alpha_tlogs : 用于日志記錄的熵權(quán)重
  • next_q_value : 平均下一個(gè) Q 值
  • dormant_metrics : 休眠神經(jīng)元的相關(guān)度量

重置機(jī)制模塊

重置機(jī)制模塊在代碼中主要體現(xiàn)在 update_parameters 函數(shù)中,并通過(guò) 梯度主導(dǎo)度 (dominant metrics) 和 擾動(dòng)函數(shù) (perturbation functions) 實(shí)現(xiàn)對(duì)策略網(wǎng)絡(luò)和 Q 網(wǎng)絡(luò)的重置。

重置邏輯

函數(shù)根據(jù)設(shè)定的 reset_interval 判斷是否需要對(duì)策略網(wǎng)絡(luò)和 Q 網(wǎng)絡(luò)進(jìn)行擾動(dòng)和重置。這里使用了"休眠"神經(jīng)元的概念,即一些在梯度更新中影響較小的神經(jīng)元,可能會(huì)被調(diào)整或重置。

函數(shù)計(jì)算了休眠度量 dormant_metrics 和因果權(quán)重差異 causal_diff ,通過(guò)擾動(dòng)因子 perturb_factor 來(lái)決定是否對(duì)網(wǎng)絡(luò)進(jìn)行部分或全部的擾動(dòng)與重置。

重置機(jī)制模塊的原理

重置機(jī)制主要由以下部分組成:

1. 計(jì)算梯度主導(dǎo)度 ( $\beta_\gamma $)

在更新策略時(shí),計(jì)算 主導(dǎo)梯度 ,即某些特定神經(jīng)元或參數(shù)在更新中主導(dǎo)作用的比率。代碼中通過(guò)調(diào)用 cal_dormant_grad(self.policy, type='policy', percentage=0.05) 實(shí)現(xiàn)這一計(jì)算,代表提取出 5% 的主導(dǎo)梯度來(lái)作為判斷因子。

dormant_metrics = cal_dormant_grad(self.policy, type='policy', percentage=0.05)

根據(jù)主導(dǎo)度 ($ \beta_\gamma$ ) 和權(quán)重 ($ w$ ),可以得到因果效應(yīng)的差異。代碼里用 causal_diff 來(lái)表示因果差異:

2. 軟重置策略和 Q 網(wǎng)絡(luò)

軟重置機(jī)制通過(guò)平滑更新策略網(wǎng)絡(luò)和 Q 網(wǎng)絡(luò),避免過(guò)大的權(quán)重更新導(dǎo)致的網(wǎng)絡(luò)不穩(wěn)定。這在代碼中由 soft_update 實(shí)現(xiàn):

soft_update(self.critic_target, self.critic, self.tau)

具體來(lái)說(shuō),軟更新的公式為:

其中,( \(\tau\) ) 是一個(gè)較小的常數(shù),通常介于 ( [0, 1] ) 之間,確保目標(biāo)網(wǎng)絡(luò)的更新是緩慢的,以提高學(xué)習(xí)的穩(wěn)定性。

3. 策略和 Q 優(yōu)化器的重置
4. 重置機(jī)制模塊的應(yīng)用

每當(dāng)經(jīng)過(guò)一定的重置間隔時(shí),判斷是否需要擾動(dòng)策略和 Q 網(wǎng)絡(luò)。通過(guò)調(diào)用 perturb() dormant_perturb() 實(shí)現(xiàn)對(duì)網(wǎng)絡(luò)的擾動(dòng)(perturbation)。擾動(dòng)因子由梯度主導(dǎo)度和因果差異共同決定。

策略與 Q 網(wǎng)絡(luò)的擾動(dòng)會(huì)在以下兩種情況下發(fā)生:

a. 重置間隔達(dá)成時(shí)

代碼中每當(dāng)更新次數(shù) updates 達(dá)到設(shè)定的重置間隔 self.reset_interval ,并且 updates > 5000 時(shí),才會(huì)觸發(fā)策略與 Q 網(wǎng)絡(luò)的重置邏輯。這是為了確保擾動(dòng)不是頻繁發(fā)生,而是在經(jīng)過(guò)一段較長(zhǎng)的訓(xùn)練時(shí)間后進(jìn)行。

具體判斷條件:

if updates % self.reset_interval == 0 and updates > 5000:
b. 主導(dǎo)梯度或因果效應(yīng)差異滿足條件時(shí)

在達(dá)到了重置間隔后,首先會(huì)計(jì)算 梯度主導(dǎo)度 因果效應(yīng)的差異 。這可以通過(guò)計(jì)算因果差異 causal_diff 或梯度主導(dǎo)度 dormant_metrics['policy_grad_dormant_ratio'] 來(lái)決定是否需要擾動(dòng)。

  • 梯度主導(dǎo)度 計(jì)算方式通過(guò) cal_dormant_grad() 函數(shù)實(shí)現(xiàn),如果梯度主導(dǎo)度較低,意味著網(wǎng)絡(luò)中的某些神經(jīng)元更新幅度過(guò)小,則需要對(duì)網(wǎng)絡(luò)進(jìn)行擾動(dòng)。

  • 因果效應(yīng)差異 通過(guò)計(jì)算 causal_diff = np.max(causal_weight) - np.min(causal_weight) 得到,如果差異過(guò)大,則可能需要重置。

然后根據(jù)這些值通過(guò)擾動(dòng)因子 factor 進(jìn)行判斷:

factor = perturb_factor(dormant_metrics['policy_grad_dormant_ratio'])

如果擾動(dòng)因子 ( \(\text{factor} < 1\) ),網(wǎng)絡(luò)會(huì)進(jìn)行擾動(dòng):

if factor < 1:
    if self.reset == 'reset' or self.reset == 'causal_reset':
        perturb(self.policy, self.policy_optim, factor)
        perturb(self.critic, self.critic_optim, factor)
        perturb(self.critic_target, self.critic_optim, factor)
c. 總結(jié)
  • 更新次數(shù)達(dá)到設(shè)定的重置間隔 ,且經(jīng)過(guò)了一定時(shí)間的訓(xùn)練( updates > 5000 )。
  • 梯度主導(dǎo)度 較低或 因果效應(yīng)差異 過(guò)大,導(dǎo)致計(jì)算出的擾動(dòng)因子 ( $\text{factor} < 1 $)。

這兩種條件同時(shí)滿足時(shí),策略和 Q 網(wǎng)絡(luò)將被擾動(dòng)或重置。

擾動(dòng)因子的計(jì)算

在這段代碼中, factor 是基于網(wǎng)絡(luò)中梯度主導(dǎo)度或者因果效應(yīng)差異計(jì)算出來(lái)的擾動(dòng)因子。擾動(dòng)因子通過(guò)函數(shù) perturb_factor() 進(jìn)行計(jì)算,該函數(shù)會(huì)根據(jù)神經(jīng)元的梯度主導(dǎo)度( dormant_ratio )或因果效應(yīng)差異( causal_diff )來(lái)調(diào)整 factor 的大小。

擾動(dòng)因子(factor)

擾動(dòng)因子 factor 的計(jì)算公式如下:

其中:

  • ( \(\text{dormant\_ratio}\) ) 是網(wǎng)絡(luò)中梯度主導(dǎo)度,即表示有多少神經(jīng)元的梯度變化較。ɑ蛘呓咏悖幱凇靶菝摺睜顟B(tài)。

  • ( \(\text{min\_perturb\_factor}\) ) 是最小擾動(dòng)因子值,代碼中設(shè)定為 0.2 。

  • ( \(\text{max\_perturb\_factor}\) ) 是最大擾動(dòng)因子值,代碼中設(shè)定為 0.9 。

  • dormant_ratio :

    • 表示網(wǎng)絡(luò)中處于“休眠狀態(tài)”的梯度比例。這個(gè)比例通常通過(guò)計(jì)算神經(jīng)網(wǎng)絡(luò)中梯度幅度小于某個(gè)閾值的神經(jīng)元數(shù)量來(lái)獲得。 dormant_ratio 越大,表示越多神經(jīng)元的梯度變化很小,說(shuō)明網(wǎng)絡(luò)更新不充分,需要擾動(dòng)。
  • max_perturb_factor :

    • 最大擾動(dòng)因子值,用來(lái)限制擾動(dòng)因子的上限,代碼中設(shè)定為 0.9,意味著最大擾動(dòng)幅度不會(huì)超過(guò) 90%。
  • min_perturb_factor :

    • 最小擾動(dòng)因子值,用來(lái)限制擾動(dòng)因子的下限,代碼中設(shè)定為 0.2,意味著即使休眠神經(jīng)元比例很低,擾動(dòng)幅度也不會(huì)小于 20%。

在計(jì)算因果效應(yīng)的部分,擾動(dòng)因子 factor 還會(huì)根據(jù)因果效應(yīng)差異 causal_diff 來(lái)調(diào)整。 causal_diff 是通過(guò)計(jì)算因果效應(yīng)的最大值與最小值的差異來(lái)獲得的:

計(jì)算出的 causal_diff 會(huì)影響 causal_factor ,并進(jìn)一步對(duì) factor 進(jìn)行調(diào)整:

組合擾動(dòng)因子的公式

最后,如果選擇了因果重置( causal_reset ),擾動(dòng)因子將使用因果差異計(jì)算出的 causal_factor 進(jìn)行二次調(diào)整:

綜上所述, factor 的最終值是由梯度主導(dǎo)度或因果效應(yīng)差異來(lái)控制的,當(dāng)休眠神經(jīng)元比例較大或因果效應(yīng)差異較大時(shí), factor 會(huì)減小,導(dǎo)致網(wǎng)絡(luò)進(jìn)行擾動(dòng)。

評(píng)估代碼

這段代碼主要實(shí)現(xiàn)了在強(qiáng)化學(xué)習(xí)(RL)訓(xùn)練過(guò)程中,定期評(píng)估智能體(agent)的性能,并在某些條件下保存最佳模型的檢查點(diǎn)。我們可以分段解釋該代碼:

1. 定期評(píng)估條件

if i_episode % config.eval_interval == 0 and config.eval is True:

這部分代碼用于判斷是否應(yīng)該執(zhí)行智能體的評(píng)估。條件為:

  • i_episode % config.eval_interval == 0 :表示每隔 config.eval_interval 個(gè)訓(xùn)練回合( i_episode 是當(dāng)前回合數(shù))進(jìn)行一次評(píng)估。
  • config.eval is True :確保 eval 設(shè)置為 True ,也就是說(shuō),評(píng)估功能開(kāi)啟。

如果滿足這兩個(gè)條件,代碼將開(kāi)始執(zhí)行評(píng)估操作。

2. 初始化評(píng)估列表

eval_reward_list = []

用于存儲(chǔ)每個(gè)評(píng)估回合(episode)的累計(jì)獎(jiǎng)勵(lì),以便之后計(jì)算平均獎(jiǎng)勵(lì)。

3. 進(jìn)行評(píng)估

for _ in range(config.eval_episodes):

評(píng)估階段將運(yùn)行多個(gè)回合(由 config.eval_episodes 指定的回合數(shù)),以獲得智能體的表現(xiàn)。

3.1 回合初始化
state = env.reset()
episode_reward = []
done = False
  • env.reset() :重置環(huán)境,獲得初始狀態(tài) state 。
  • episode_reward :初始化一個(gè)列表,用于存儲(chǔ)當(dāng)前回合中智能體獲得的所有獎(jiǎng)勵(lì)。
  • done = False :用 done 來(lái)跟蹤當(dāng)前回合是否結(jié)束。
3.2 執(zhí)行智能體動(dòng)作
while not done:
    action = agent.select_action(state, evaluate=True)
    next_state, reward, done, info = env.step(action)
    state = next_state
    episode_reward.append(reward)
  • 動(dòng)作選擇 agent.select_action(state, evaluate=True) 在評(píng)估模式下根據(jù)當(dāng)前狀態(tài) state 選擇動(dòng)作。 evaluate=True 表示該選擇是在評(píng)估模式下,通常意味著探索行為被關(guān)閉(即不進(jìn)行隨機(jī)探索,而是選擇最優(yōu)動(dòng)作)。

  • 環(huán)境反饋 next_state, reward, done, info = env.step(action) 通過(guò)執(zhí)行動(dòng)作 action ,環(huán)境返回下一個(gè)狀態(tài) next_state ,當(dāng)前獎(jiǎng)勵(lì) reward ,回合是否結(jié)束的標(biāo)志 done ,以及附加信息 info 。

  • 狀態(tài)更新 :當(dāng)前狀態(tài)被更新為 next_state ,并將獲得的獎(jiǎng)勵(lì) reward 存儲(chǔ)在 episode_reward 列表中。

循環(huán)持續(xù),直到回合結(jié)束(即 done == True )。

3.3 存儲(chǔ)回合獎(jiǎng)勵(lì)
eval_reward_list.append(sum(episode_reward))

當(dāng)前回合結(jié)束后,累計(jì)獎(jiǎng)勵(lì)( sum(episode_reward) )被添加到 eval_reward_list ,用于后續(xù)計(jì)算平均獎(jiǎng)勵(lì)。

4. 計(jì)算平均獎(jiǎng)勵(lì)

avg_reward = np.average(eval_reward_list)

在所有評(píng)估回合結(jié)束后,計(jì)算 eval_reward_list 的平均值 avg_reward 。這是當(dāng)前評(píng)估階段智能體的表現(xiàn)指標(biāo)。

5. 保存最佳模型

if config.save_checkpoint:
    if avg_reward >= best_reward:
        best_reward = avg_reward
        agent.save_checkpoint(checkpoint_path, 'best')
  • 如果 config.save_checkpoint True ,則表示需要檢查是否保存模型。
  • 通過(guò)判斷 avg_reward 是否超過(guò)了之前的最佳獎(jiǎng)勵(lì) best_reward ,如果是,則更新 best_reward ,并保存當(dāng)前模型的檢查點(diǎn)。
agent.save_checkpoint(checkpoint_path, 'best')

這行代碼會(huì)將智能體的狀態(tài)保存到指定的路徑 checkpoint_path ,并標(biāo)記為 "best" ,表示這是性能最佳的模型。

論文復(fù)現(xiàn)結(jié)果

咳咳,可以發(fā)現(xiàn)程序只記錄了 0~1000 的數(shù)據(jù),從 1001 開(kāi)始的每一個(gè)數(shù)據(jù)都顯示報(bào)錯(cuò)所以被舍棄掉了。

后面重新下載了github代碼包,發(fā)生了同樣的報(bào)錯(cuò)信息

報(bào)錯(cuò)信息是:你在 X+1 輪次中嘗試記載 X 輪次中的信息,所以這個(gè)數(shù)據(jù)被舍棄掉了

大概是主程序哪里有問(wèn)題吧,我自己也沒(méi)調(diào) bug

不過(guò)這個(gè)項(xiàng)目結(jié)題了,主要負(fù)責(zé)這個(gè)項(xiàng)目的博士師兄也畢業(yè)了,也不好說(shuō)些什么(雖然我有他微信),至少論文里面的模塊挺有用的。ㄊ謩(dòng)滑稽)

小編推薦閱讀

好特網(wǎng)發(fā)布此文僅為傳遞信息,不代表好特網(wǎng)認(rèn)同期限觀點(diǎn)或證實(shí)其描述。

a 1.0
a 1.0
類型:休閑益智  運(yùn)營(yíng)狀態(tài):正式運(yùn)營(yíng)  語(yǔ)言:中文   

游戲攻略

游戲禮包

游戲視頻

游戲下載

游戲活動(dòng)

《alittletotheleft》官網(wǎng)正版是一款備受歡迎的休閑益智整理游戲。玩家的任務(wù)是對(duì)日常生活中的各種雜亂物
AWA 1.40
AWA 1.40
類型:休閑益智  運(yùn)營(yíng)狀態(tài):未知  語(yǔ)言:中文   

游戲攻略

游戲禮包

游戲視頻

游戲下載

游戲活動(dòng)

《AWA》安卓版是由開(kāi)發(fā)商MentalLab研發(fā)的一款帶有奇幻色彩的神秘迷宮冒險(xiǎn)游戲,華麗而精美的游戲界面,讓

相關(guān)視頻攻略

更多

掃二維碼進(jìn)入好特網(wǎng)手機(jī)版本!

掃二維碼進(jìn)入好特網(wǎng)微信公眾號(hào)!

本站所有軟件,都由網(wǎng)友上傳,如有侵犯你的版權(quán),請(qǐng)發(fā)郵件[email protected]

湘ICP備2022002427號(hào)-10 湘公網(wǎng)安備:43070202000427號(hào)© 2013~2025 haote.com 好特網(wǎng)