您的位置:首頁(yè) > 軟件教程 > 教程 > 強(qiáng)化學(xué)習(xí)筆記之【ACE:Off-PolicyActor-CriticwithCausality-AwareEntropyRegularization】
2024年ICML文章,ACE:Off-PolicyActor-CriticwithCausality-AwareEntropyRegularization精讀
該論文是清華項(xiàng)目組組內(nèi)博士師兄寫(xiě)的文章,項(xiàng)目主頁(yè)為 ACE (ace-rl.github.io) ,于2024年7月發(fā)表在ICML期刊
因?yàn)樽罱M內(nèi)(其實(shí)只有我)需要從零開(kāi)始做一個(gè)相關(guān)項(xiàng)目,前面的幾篇文章都是鋪墊
本文章為強(qiáng)化學(xué)習(xí)筆記第5篇
本文初編輯于2024.10.5,好像是這個(gè)時(shí)間,忘記了,前后寫(xiě)了兩個(gè)多星期
CSDN主頁(yè): https://blog.csdn.net/rvdgdsva
博客園主頁(yè): https://www.cnblogs.com/hassle
博客園本文鏈接:
這篇強(qiáng)化學(xué)習(xí)論文主要介紹了一個(gè)名為 ACE 的算法,完整名稱為 Off-Policy Actor-Critic with Causality-Aware Entropy Regularization ,它通過(guò)引入因果關(guān)系分析和因果熵正則化來(lái)解決現(xiàn)有模型在不同動(dòng)作維度上的不平等探索問(wèn)題,旨在改進(jìn)強(qiáng)化學(xué)習(xí)【注釋1】中探索效率和樣本效率的問(wèn)題,特別是在高維度連續(xù)控制任務(wù)中的表現(xiàn)。
【注釋1】: 強(qiáng)化學(xué)習(xí)入門(mén)這一篇就夠了
在policy【注釋2】學(xué)習(xí)過(guò)程中,不同原始行為的不同意義被先前的model-free RL 算法所忽視。利用這一見(jiàn)解,我們探索了不同行動(dòng)維度和獎(jiǎng)勵(lì)之間的因果關(guān)系,以評(píng)估訓(xùn)練過(guò)程中各種原始行為的重要性。我們引入了一個(gè)因果關(guān)系感知熵【注釋3】項(xiàng)(causality-aware entropy term),它可以有效地識(shí)別并優(yōu)先考慮具有高潛在影響的行為,以實(shí)現(xiàn)高效的探索。此外,為了防止過(guò)度關(guān)注特定的原始行為,我們分析了梯度休眠現(xiàn)象(gradientdormancyphenomenon),并引入了休眠引導(dǎo)的重置機(jī)制,以進(jìn)一步增強(qiáng)我們方法的有效性。與無(wú)模型RL基線相比,我們提出的算法 ACE :Off-policy A ctor-criticwith C ausality-aware E ntropyregularization。在跨越7個(gè)域的29種不同連續(xù)控制任務(wù)中顯示出實(shí)質(zhì)性的性能優(yōu)勢(shì),這強(qiáng)調(diào)了我們方法的有效性、多功能性和高效的樣本效率。 基準(zhǔn)測(cè)試結(jié)果和視頻可在https://ace-rl.github.io/上獲得。
【注釋2】: 強(qiáng)化學(xué)習(xí)算法中on-policy和off-policy
【注釋3】: 最大熵 RL:從Soft Q-Learning到SAC - 知乎
【1】 因果關(guān)系分析 :通過(guò)引入因果政策-獎(jiǎng)勵(lì)結(jié)構(gòu)模型,評(píng)估不同動(dòng)作維度(即原始行為)對(duì)獎(jiǎng)勵(lì)的影響大小(稱為“因果權(quán)重”)。這些權(quán)重反映了每個(gè)動(dòng)作維度在不同學(xué)習(xí)階段的相對(duì)重要性。
作出上述改進(jìn)的原因是:考慮一個(gè)簡(jiǎn)單的例子,一個(gè)機(jī)械手最初應(yīng)該學(xué)習(xí)放下手臂并抓住物體,然后將注意力轉(zhuǎn)移到學(xué)習(xí)手臂朝著最終目標(biāo)的運(yùn)動(dòng)方向上。因此,在策略學(xué)習(xí)的不同階段強(qiáng)調(diào)對(duì)最重要的原始行為的探索是 至關(guān)重要的。在探索過(guò)程中刻意關(guān)注各種原始行為,可以加速智能體在每個(gè)階段對(duì)基本原始行為的學(xué)習(xí),從而提高掌握完整運(yùn)動(dòng)任務(wù)的效率。
此處可供學(xué)習(xí)的資料:
【2】 因果熵正則化 :在最大熵強(qiáng)化學(xué)習(xí)框架的基礎(chǔ)上(如SAC算法),加入了 因果加權(quán)的熵正則化項(xiàng) 。與傳統(tǒng)熵正則化不同,這一項(xiàng)根據(jù)各個(gè)原始行為的因果權(quán)重動(dòng)態(tài)調(diào)整,強(qiáng)化對(duì)重要行為的探索,減少對(duì)不重要行為的探索。
作出上述改進(jìn)的原因是:論文引入了一個(gè)因果策略-獎(jiǎng)勵(lì)結(jié)構(gòu)模型來(lái)計(jì)算行動(dòng)空間上的因果權(quán)重(causal weights),因果權(quán)重會(huì)引導(dǎo)agent進(jìn)行更有效的探索, 鼓勵(lì)對(duì)因果權(quán)重較大的動(dòng)作維度進(jìn)行探索,表明對(duì)獎(jiǎng)勵(lì)的重要性更大,并減少對(duì)因果權(quán)重較小的行為維度的探 索。一般的最大熵目標(biāo)缺乏對(duì)不同學(xué)習(xí)階段原始行為之間區(qū)別的重要性的認(rèn)識(shí),可能導(dǎo)致低效的探索。為了解決這一限制,論文引入了一個(gè)由因果權(quán)重加權(quán)的策略熵作為因果關(guān)系感知的熵最大化目標(biāo),有效地加強(qiáng)了對(duì)重要原始行為的探索,并導(dǎo)致了更有效的探索。
此處可供學(xué)習(xí)的資料:
【3】 梯度“休眠”現(xiàn)象(Gradient Dormancy) :論文觀察到,模型訓(xùn)練時(shí)有些梯度會(huì)在某些階段不活躍(即“休眠”)。為了防止模型過(guò)度關(guān)注某些原始行為,論文引入了 梯度休眠導(dǎo)向的重置機(jī)制 。該機(jī)制通過(guò)周期性地對(duì)模型進(jìn)行擾動(dòng)(reset),避免模型陷入局部最優(yōu),促進(jìn)更廣泛的探索。
作出上述改進(jìn)的原因是:該機(jī)制通過(guò)一個(gè)由梯度休眠程度決定的因素間歇性地干擾智能體的神經(jīng)網(wǎng)絡(luò)。將因果關(guān)系感知探索與這種新穎的重置機(jī)制相結(jié)合,旨在促進(jìn)更高效、更有效的探索,最終提高智能體的整體性能。
通過(guò)在多個(gè)連續(xù)控制任務(wù)中的實(shí)驗(yàn),ACE 展示出了顯著優(yōu)于主流強(qiáng)化學(xué)習(xí)算法(如SAC、TD3)的表現(xiàn):
論文中的對(duì)比實(shí)驗(yàn)圖表顯示了 ACE 在多種任務(wù)下的顯著優(yōu)勢(shì),尤其是在 稀疏獎(jiǎng)勵(lì)和高維度任務(wù) 中,ACE 憑借其探索效率的提升,能更快達(dá)到最優(yōu)策略。
在ACE原論文的第21頁(yè),這玩意兒應(yīng)該寫(xiě)在正篇的,害的我看了好久的代碼去排流程
不過(guò)說(shuō)實(shí)話這偽代碼有夠簡(jiǎn)潔的,代碼多少有點(diǎn)糊成一坨了
這是一個(gè)強(qiáng)化學(xué)習(xí)(RL)算法的框架,具體是一個(gè)結(jié)合因果推斷(Causal Discovery)的離策略(Off-policy)Actor-Critic方法。下面是對(duì)每個(gè)模塊及其參數(shù)的說(shuō)明:
源代碼上千行呢,這里只是貼上main_casual里面的部分代碼,并且刪掉了很大一部分代碼以便理清程序脈絡(luò)
def train_loop(config, msg = "default"):
# Agent
agent = ACE_agent(env.observation_space.shape[0], env.action_space, config)
memory = ReplayMemory(config.replay_size, config.seed)
local_buffer = ReplayMemory(config.causal_sample_size, config.seed)
for i_episode in itertools.count(1):
done = False
state = env.reset()
while not done:
if config.start_steps > total_numsteps:
action = env.action_space.sample() # Sample random action
else:
action = agent.select_action(state) # Sample action from policy
if len(memory) > config.batch_size:
for i in range(config.updates_per_step):
#* Update parameters of causal weight
if (total_numsteps % config.causal_sample_interval == 0) and (len(local_buffer)>=config.causal_sample_size):
causal_weight, causal_computing_time = get_sa2r_weight(env, local_buffer, agent, sample_size=config.causal_sample_size, causal_method='DirectLiNGAM')
print("Current Causal Weight is: ",causal_weight)
dormant_metrics = {}
# Update parameters of all the networks
critic_1_loss, critic_2_loss, policy_loss, ent_loss, alpha, q_sac, dormant_metrics = agent.update_parameters(memory, causal_weight,config.batch_size, updates)
updates += 1
next_state, reward, done, info = env.step(action) # Step
total_numsteps += 1
episode_steps += 1
episode_reward += reward
#* Ignore the "done" signal if it comes from hitting the time horizon.
if '_max_episode_steps' in dir(env):
mask = 1 if episode_steps == env._max_episode_steps else float(not done)
elif 'max_path_length' in dir(env):
mask = 1 if episode_steps == env.max_path_length else float(not done)
else:
mask = 1 if episode_steps == 1000 else float(not done)
memory.push(state, action, reward, next_state, mask) # Append transition to memory
local_buffer.push(state, action, reward, next_state, mask) # Append transition to local_buffer
state = next_state
if total_numsteps > config.num_steps:
break
# test agent
if i_episode % config.eval_interval == 0 and config.eval is True:
eval_reward_list = []
for _ in range(config.eval_episodes):
state = env.reset()
episode_reward = []
done = False
while not done:
action = agent.select_action(state, evaluate=True)
next_state, reward, done, info = env.step(action)
state = next_state
episode_reward.append(reward)
eval_reward_list.append(sum(episode_reward))
avg_reward = np.average(eval_reward_list)
env.close()
初始化 :
config
設(shè)置環(huán)境和隨機(jī)種子。
ACE_agent
初始化強(qiáng)化學(xué)習(xí)智能體,該智能體會(huì)在后續(xù)過(guò)程中學(xué)習(xí)如何在環(huán)境中行動(dòng)。
memory
用于存儲(chǔ)所有的歷史數(shù)據(jù),
local_buffer
則用于因果權(quán)重的更新。
主訓(xùn)練循環(huán) :
采樣動(dòng)作 :如果總步數(shù)較小,則從環(huán)境中隨機(jī)采樣動(dòng)作,否則從策略中選擇動(dòng)作。通過(guò)這種方式,確保早期探索和后期利用。
更新因果權(quán)重
:在特定間隔內(nèi),從局部緩沖區(qū)中采樣數(shù)據(jù),通過(guò)
get_sa2r_weight
函數(shù)使用DirectLiNGAM算法計(jì)算從動(dòng)作到獎(jiǎng)勵(lì)的因果權(quán)重。這個(gè)權(quán)重會(huì)作為額外信息,幫助智能體優(yōu)化策略。
更新網(wǎng)絡(luò)參數(shù)
:當(dāng)
memory
中的數(shù)據(jù)足夠多時(shí),開(kāi)始通過(guò)采樣更新Q網(wǎng)絡(luò)和策略網(wǎng)絡(luò),使用計(jì)算出的因果權(quán)重來(lái)修正損失函數(shù)。
記錄與保存模型 :每隔一定的步數(shù),算法會(huì)測(cè)試當(dāng)前策略的性能,記錄并比較獎(jiǎng)勵(lì)是否超過(guò)歷史最佳值,如果是,則保存模型的檢查點(diǎn)。
使用
wandb
記錄訓(xùn)練過(guò)程中的指標(biāo),例如損失函數(shù)、獎(jiǎng)勵(lì)和因果權(quán)重的計(jì)算時(shí)間,這些信息可以幫助調(diào)試和分析訓(xùn)練過(guò)程。
因果發(fā)現(xiàn)模塊
主要通過(guò)
get_sa2r_weight
函數(shù)實(shí)現(xiàn),并且與
DirectLiNGAM
模型結(jié)合,負(fù)責(zé)計(jì)算因果權(quán)重。具體代碼在訓(xùn)練循環(huán)中如下:
causal_weight, causal_computing_time = get_sa2r_weight(env, local_buffer, agent, sample_size=config.causal_sample_size, causal_method='DirectLiNGAM')
在這個(gè)代碼段,
get_sa2r_weight
函數(shù)會(huì)基于當(dāng)前環(huán)境、樣本數(shù)據(jù)(
local_buffer
)和因果模型(這里使用的是
DirectLiNGAM
),計(jì)算與行動(dòng)相關(guān)的因果權(quán)重(
causal_weight
)。這些權(quán)重會(huì)影響后續(xù)的策略優(yōu)化和參數(shù)更新。關(guān)鍵邏輯包括:
total_numsteps % config.causal_sample_interval == 0
時(shí)觸發(fā),確保只在指定的步數(shù)間隔內(nèi)計(jì)算因果權(quán)重,避免每一步都進(jìn)行因果計(jì)算,減輕計(jì)算負(fù)擔(dān)。
local_buffer
中存儲(chǔ)了足夠的樣本(
config.causal_sample_size
),這些樣本用于因果關(guān)系的發(fā)現(xiàn)。
DirectLiNGAM
是選擇的因果模型,用于從狀態(tài)、行動(dòng)和獎(jiǎng)勵(lì)之間推導(dǎo)出因果關(guān)系。
因果權(quán)重計(jì)算完成后,程序會(huì)將這些權(quán)重應(yīng)用到策略優(yōu)化中,并且記錄權(quán)重及計(jì)算時(shí)間等信息。
def get_sa2r_weight(env, memory, agent, sample_size=5000, causal_method='DirectLiNGAM'):
······
return weight, model._running_time
這個(gè)代碼的核心是利用DirectLiNGAM模型計(jì)算給定狀態(tài)、動(dòng)作和獎(jiǎng)勵(lì)之間的因果權(quán)重。接下來(lái),用LaTeX公式詳細(xì)表述計(jì)算因果權(quán)重的過(guò)程:
數(shù)據(jù)預(yù)處理
:
將從
memory
中采樣的
states
(狀態(tài))、
actions
(動(dòng)作)和
rewards
(獎(jiǎng)勵(lì))進(jìn)行拼接,構(gòu)建輸入數(shù)據(jù)矩陣
\(X_{\text{ori}}\)
:
其中, \(S\) 代表狀態(tài), \(A\) 代表動(dòng)作, \(R\) 代表獎(jiǎng)勵(lì)。接著,構(gòu)建數(shù)據(jù)框 \(X\) 來(lái)進(jìn)行因果分析。
因果模型擬合 :
將
X_ori
轉(zhuǎn)換為
X
是為了利用
pandas
數(shù)據(jù)框的便利性和靈活性
使用 DirectLiNGAM 模型對(duì)矩陣 \(X\) 進(jìn)行擬合,得到因果關(guān)系的鄰接矩陣 \(A_{\text{model}}\) :
該鄰接矩陣表示狀態(tài)、動(dòng)作、獎(jiǎng)勵(lì)之間的因果結(jié)構(gòu),特別是從動(dòng)作到獎(jiǎng)勵(lì)的影響關(guān)系。
提取動(dòng)作對(duì)獎(jiǎng)勵(lì)的因果權(quán)重
:
通過(guò)鄰接矩陣提取動(dòng)作對(duì)獎(jiǎng)勵(lì)的因果權(quán)重
\(w_{\text{r}}\)
,該權(quán)重從鄰接矩陣的最后一行中選擇與動(dòng)作對(duì)應(yīng)的元素:
其中, \(d_s\) 是狀態(tài)的維度, \(d_a\) 是動(dòng)作的維度。
因果權(quán)重的歸一化
:
對(duì)因果權(quán)重
\(w_{\text{r}}\)
進(jìn)行Softmax歸一化,確保它們的總和為1:
調(diào)整權(quán)重的尺度
:
最后,因果權(quán)重根據(jù)動(dòng)作的數(shù)量進(jìn)行縮放:
最終輸出的權(quán)重 \(w\) 表示每個(gè)動(dòng)作對(duì)獎(jiǎng)勵(lì)的因果影響,經(jīng)過(guò)歸一化和縮放處理,可以用于進(jìn)一步的策略調(diào)整或分析。
以下是對(duì)函數(shù)工作原理的逐步解釋:
策略優(yōu)化模塊
主要由
agent.update_parameters
函數(shù)實(shí)現(xiàn)。
agent.update_parameters
這個(gè)函數(shù)的主要目的是在強(qiáng)化學(xué)習(xí)中更新策略 (
policy
) 和價(jià)值網(wǎng)絡(luò)(critic)的參數(shù),以提升智能體的性能。這個(gè)函數(shù)實(shí)現(xiàn)了一個(gè)基于軟演員評(píng)論家(SAC, Soft Actor-Critic)的更新機(jī)制,并且加入了因果權(quán)重與"休眠"神經(jīng)元(dormant neurons)的處理,以提高模型的魯棒性和穩(wěn)定性。
critic_1_loss, critic_2_loss, policy_loss, ent_loss, alpha, q_sac, dormant_metrics = agent.update_parameters(memory, causal_weight, config.batch_size, updates)
通過(guò)
agent.update_parameters
函數(shù),程序會(huì)更新以下幾個(gè)部分:
critic_1_loss
和
critic_2_loss
分別是兩個(gè) Critic 網(wǎng)絡(luò)的損失,用于評(píng)估當(dāng)前策略的價(jià)值。
policy_loss
表示策略網(wǎng)絡(luò)的損失,用于優(yōu)化 agent 的行動(dòng)選擇。
ent_loss
用來(lái)調(diào)節(jié)策略的隨機(jī)性,幫助 agent 在探索和利用之間找到平衡。
這些參數(shù)的更新在每次訓(xùn)練循環(huán)中被調(diào)用,并使用
wandb.log
記錄損失和其他相關(guān)的訓(xùn)練數(shù)據(jù)。
update_parameters
是
ACE_agent
類中的一個(gè)關(guān)鍵函數(shù),用于根據(jù)經(jīng)驗(yàn)回放緩沖區(qū)中的樣本數(shù)據(jù)來(lái)更新模型的參數(shù)。下面是對(duì)其工作原理的詳細(xì)解釋:
首先,函數(shù)從
memory
中采樣一批樣本(
state_batch
、
action_batch
、
reward_batch
、
next_state_batch
、
mask_batch
),其中包括狀態(tài)、動(dòng)作、獎(jiǎng)勵(lì)、下一個(gè)狀態(tài)以及掩碼,用于表示是否為終止?fàn)顟B(tài)。
state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(batch_size=batch_size)
state_batch
:當(dāng)前的狀態(tài)。
action_batch
:在當(dāng)前狀態(tài)下執(zhí)行的動(dòng)作。
reward_batch
:執(zhí)行該動(dòng)作后獲得的獎(jiǎng)勵(lì)。
next_state_batch
:執(zhí)行動(dòng)作后到達(dá)的下一個(gè)狀態(tài)。
mask_batch
:掩碼,用于表示是否為終止?fàn)顟B(tài)(1 表示非終止,0 表示終止)。
利用當(dāng)前策略(policy)網(wǎng)絡(luò),采樣下一個(gè)狀態(tài)的動(dòng)作
next_state_action
和其對(duì)應(yīng)的概率分布對(duì)數(shù)
next_state_log_pi
。然后利用目標(biāo) Q 網(wǎng)絡(luò)
critic_target
估計(jì)下一時(shí)刻的最小 Q 值,并結(jié)合獎(jiǎng)勵(lì)和折扣因子
\(\gamma\)
計(jì)算下一個(gè) Q 值:
with torch.no_grad():
next_state_action, next_state_log_pi, _ = self.policy.sample(next_state_batch, causal_weight)
qf1_next_target, qf2_next_target = self.critic_target(next_state_batch, next_state_action)
min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - self.alpha * next_state_log_pi
next_q_value = reward_batch + mask_batch * self.gamma * (min_qf_next_target)
通過(guò)策略網(wǎng)絡(luò)
self.policy
為下一個(gè)狀態(tài)
next_state_batch
采樣動(dòng)作
next_state_action
和相應(yīng)的策略熵
next_state_log_pi
。
使用目標(biāo) Q 網(wǎng)絡(luò)計(jì)算
qf1_next_target
和
qf2_next_target
,并取兩者的最小值來(lái)減少估計(jì)偏差。
最終使用貝爾曼方程計(jì)算
next_q_value
,即當(dāng)前的獎(jiǎng)勵(lì)加上折扣因子
\(\gamma\)
乘以下一個(gè)狀態(tài)的 Q 值。
這里,
\(\alpha\)
是熵項(xiàng)的權(quán)重,用于平衡探索和利用的權(quán)衡,而
mask_batch
是為了處理終止?fàn)顟B(tài)的情況。
使用無(wú)偏估計(jì)來(lái)計(jì)算目標(biāo) Q 值。通過(guò)目標(biāo)網(wǎng)絡(luò) (
critic_target
) 計(jì)算出下一個(gè)狀態(tài)和動(dòng)作的 Q 值,并使用獎(jiǎng)勵(lì)和掩碼更新當(dāng)前 Q 值
接著,使用當(dāng)前 Q 網(wǎng)絡(luò)
critic
估計(jì)當(dāng)前狀態(tài)和動(dòng)作下的 Q 值
\(Q_1\)
和
\(Q_2\)
,并計(jì)算它們與目標(biāo) Q 值的均方誤差損失:
最終 Q 網(wǎng)絡(luò)的總損失是兩個(gè) Q 網(wǎng)絡(luò)損失之和:
然后,通過(guò)反向傳播
qf_loss
來(lái)更新 Q 網(wǎng)絡(luò)的參數(shù)。
qf1, qf2 = self.critic(state_batch, action_batch)
qf1_loss = F.mse_loss(qf1, next_q_value)
qf2_loss = F.mse_loss(qf2, next_q_value)
qf_loss = qf1_loss + qf2_loss
self.critic_optim.zero_grad()
qf_loss.backward()
self.critic_optim.step()
qf1
和
qf2
是兩個(gè) Q 網(wǎng)絡(luò)的輸出,用于減少正向估計(jì)偏差。
qf1_loss
和
qf2_loss
分別計(jì)算兩個(gè) Q 網(wǎng)絡(luò)的誤差,最后將兩者相加為總的 Q 損失
qf_loss
。
self.critic_optim
優(yōu)化器對(duì)損失進(jìn)行反向傳播和參數(shù)更新。
每隔若干步(通過(guò)
target_update_interval
控制),開(kāi)始更新策略網(wǎng)絡(luò)
policy
。首先,重新采樣當(dāng)前狀態(tài)下的策略
\(\pi(a|s)\)
,并計(jì)算 Q 值和熵權(quán)重下的策略損失:
這個(gè)損失通過(guò)反向傳播更新策略網(wǎng)絡(luò)。
if updates % self.target_update_interval == 0:
pi, log_pi, _ = self.policy.sample(state_batch, causal_weight)
qf1_pi, qf2_pi = self.critic(state_batch, pi)
min_qf_pi = torch.min(qf1_pi, qf2_pi)
policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean()
self.policy_optim.zero_grad()
policy_loss.backward()
self.policy_optim.step()
state_batch
進(jìn)行采樣,得到動(dòng)作
pi
及其對(duì)應(yīng)的策略熵
log_pi
。
policy_loss
,即
\(\alpha\)
倍的策略熵減去最小的 Q 值。
self.policy_optim
優(yōu)化器對(duì)策略損失進(jìn)行反向傳播和參數(shù)更新。
如果開(kāi)啟了自動(dòng)熵項(xiàng)調(diào)整(
automatic_entropy_tuning
),則會(huì)進(jìn)一步更新熵項(xiàng)
\(\alpha\)
的損失:
并通過(guò)梯度下降更新 \(\alpha\) 。
如果
automatic_entropy_tuning
為真,則會(huì)更新熵項(xiàng):
if self.automatic_entropy_tuning:
alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()
self.alpha_optim.zero_grad()
alpha_loss.backward()
self.alpha_optim.step()
self.alpha = self.log_alpha.exp()
alpha_tlogs = self.alpha.clone()
else:
alpha_loss = torch.tensor(0.).to(self.device)
alpha_tlogs = torch.tensor(self.alpha) # For TensorboardX logs
alpha_loss
更新
self.alpha
,調(diào)整策略的探索-利用平衡。
qf1_loss
,
qf2_loss
: 兩個(gè) Q 網(wǎng)絡(luò)的損失
policy_loss
: 策略網(wǎng)絡(luò)的損失
alpha_loss
: 熵權(quán)重的損失
alpha_tlogs
: 用于日志記錄的熵權(quán)重
next_q_value
: 平均下一個(gè) Q 值
dormant_metrics
: 休眠神經(jīng)元的相關(guān)度量
重置機(jī)制模塊在代碼中主要體現(xiàn)在
update_parameters
函數(shù)中,并通過(guò)
梯度主導(dǎo)度
(dominant metrics) 和
擾動(dòng)函數(shù)
(perturbation functions) 實(shí)現(xiàn)對(duì)策略網(wǎng)絡(luò)和 Q 網(wǎng)絡(luò)的重置。
函數(shù)根據(jù)設(shè)定的
reset_interval
判斷是否需要對(duì)策略網(wǎng)絡(luò)和 Q 網(wǎng)絡(luò)進(jìn)行擾動(dòng)和重置。這里使用了"休眠"神經(jīng)元的概念,即一些在梯度更新中影響較小的神經(jīng)元,可能會(huì)被調(diào)整或重置。
函數(shù)計(jì)算了休眠度量
dormant_metrics
和因果權(quán)重差異
causal_diff
,通過(guò)擾動(dòng)因子
perturb_factor
來(lái)決定是否對(duì)網(wǎng)絡(luò)進(jìn)行部分或全部的擾動(dòng)與重置。
重置機(jī)制主要由以下部分組成:
在更新策略時(shí),計(jì)算
主導(dǎo)梯度
,即某些特定神經(jīng)元或參數(shù)在更新中主導(dǎo)作用的比率。代碼中通過(guò)調(diào)用
cal_dormant_grad(self.policy, type='policy', percentage=0.05)
實(shí)現(xiàn)這一計(jì)算,代表提取出 5% 的主導(dǎo)梯度來(lái)作為判斷因子。
dormant_metrics = cal_dormant_grad(self.policy, type='policy', percentage=0.05)
根據(jù)主導(dǎo)度 ($ \beta_\gamma$ ) 和權(quán)重 ($ w$ ),可以得到因果效應(yīng)的差異。代碼里用
causal_diff
來(lái)表示因果差異:
軟重置機(jī)制通過(guò)平滑更新策略網(wǎng)絡(luò)和 Q 網(wǎng)絡(luò),避免過(guò)大的權(quán)重更新導(dǎo)致的網(wǎng)絡(luò)不穩(wěn)定。這在代碼中由
soft_update
實(shí)現(xiàn):
soft_update(self.critic_target, self.critic, self.tau)
具體來(lái)說(shuō),軟更新的公式為:
其中,( \(\tau\) ) 是一個(gè)較小的常數(shù),通常介于 ( [0, 1] ) 之間,確保目標(biāo)網(wǎng)絡(luò)的更新是緩慢的,以提高學(xué)習(xí)的穩(wěn)定性。
每當(dāng)經(jīng)過(guò)一定的重置間隔時(shí),判斷是否需要擾動(dòng)策略和 Q 網(wǎng)絡(luò)。通過(guò)調(diào)用
perturb()
和
dormant_perturb()
實(shí)現(xiàn)對(duì)網(wǎng)絡(luò)的擾動(dòng)(perturbation)。擾動(dòng)因子由梯度主導(dǎo)度和因果差異共同決定。
策略與 Q 網(wǎng)絡(luò)的擾動(dòng)會(huì)在以下兩種情況下發(fā)生:
代碼中每當(dāng)更新次數(shù)
updates
達(dá)到設(shè)定的重置間隔
self.reset_interval
,并且
updates > 5000
時(shí),才會(huì)觸發(fā)策略與 Q 網(wǎng)絡(luò)的重置邏輯。這是為了確保擾動(dòng)不是頻繁發(fā)生,而是在經(jīng)過(guò)一段較長(zhǎng)的訓(xùn)練時(shí)間后進(jìn)行。
具體判斷條件:
if updates % self.reset_interval == 0 and updates > 5000:
在達(dá)到了重置間隔后,首先會(huì)計(jì)算
梯度主導(dǎo)度
或
因果效應(yīng)的差異
。這可以通過(guò)計(jì)算因果差異
causal_diff
或梯度主導(dǎo)度
dormant_metrics['policy_grad_dormant_ratio']
來(lái)決定是否需要擾動(dòng)。
梯度主導(dǎo)度
計(jì)算方式通過(guò)
cal_dormant_grad()
函數(shù)實(shí)現(xiàn),如果梯度主導(dǎo)度較低,意味著網(wǎng)絡(luò)中的某些神經(jīng)元更新幅度過(guò)小,則需要對(duì)網(wǎng)絡(luò)進(jìn)行擾動(dòng)。
因果效應(yīng)差異
通過(guò)計(jì)算
causal_diff = np.max(causal_weight) - np.min(causal_weight)
得到,如果差異過(guò)大,則可能需要重置。
然后根據(jù)這些值通過(guò)擾動(dòng)因子
factor
進(jìn)行判斷:
factor = perturb_factor(dormant_metrics['policy_grad_dormant_ratio'])
如果擾動(dòng)因子 ( \(\text{factor} < 1\) ),網(wǎng)絡(luò)會(huì)進(jìn)行擾動(dòng):
if factor < 1:
if self.reset == 'reset' or self.reset == 'causal_reset':
perturb(self.policy, self.policy_optim, factor)
perturb(self.critic, self.critic_optim, factor)
perturb(self.critic_target, self.critic_optim, factor)
updates > 5000
)。
這兩種條件同時(shí)滿足時(shí),策略和 Q 網(wǎng)絡(luò)將被擾動(dòng)或重置。
在這段代碼中,
factor
是基于網(wǎng)絡(luò)中梯度主導(dǎo)度或者因果效應(yīng)差異計(jì)算出來(lái)的擾動(dòng)因子。擾動(dòng)因子通過(guò)函數(shù)
perturb_factor()
進(jìn)行計(jì)算,該函數(shù)會(huì)根據(jù)神經(jīng)元的梯度主導(dǎo)度(
dormant_ratio
)或因果效應(yīng)差異(
causal_diff
)來(lái)調(diào)整
factor
的大小。
擾動(dòng)因子
factor
的計(jì)算公式如下:
其中:
( \(\text{dormant\_ratio}\) ) 是網(wǎng)絡(luò)中梯度主導(dǎo)度,即表示有多少神經(jīng)元的梯度變化較。ɑ蛘呓咏悖幱凇靶菝摺睜顟B(tài)。
(
\(\text{min\_perturb\_factor}\)
) 是最小擾動(dòng)因子值,代碼中設(shè)定為
0.2
。
(
\(\text{max\_perturb\_factor}\)
) 是最大擾動(dòng)因子值,代碼中設(shè)定為
0.9
。
dormant_ratio :
dormant_ratio
越大,表示越多神經(jīng)元的梯度變化很小,說(shuō)明網(wǎng)絡(luò)更新不充分,需要擾動(dòng)。
max_perturb_factor :
min_perturb_factor :
在計(jì)算因果效應(yīng)的部分,擾動(dòng)因子
factor
還會(huì)根據(jù)因果效應(yīng)差異
causal_diff
來(lái)調(diào)整。
causal_diff
是通過(guò)計(jì)算因果效應(yīng)的最大值與最小值的差異來(lái)獲得的:
計(jì)算出的
causal_diff
會(huì)影響
causal_factor
,并進(jìn)一步對(duì)
factor
進(jìn)行調(diào)整:
最后,如果選擇了因果重置(
causal_reset
),擾動(dòng)因子將使用因果差異計(jì)算出的
causal_factor
進(jìn)行二次調(diào)整:
綜上所述,
factor
的最終值是由梯度主導(dǎo)度或因果效應(yīng)差異來(lái)控制的,當(dāng)休眠神經(jīng)元比例較大或因果效應(yīng)差異較大時(shí),
factor
會(huì)減小,導(dǎo)致網(wǎng)絡(luò)進(jìn)行擾動(dòng)。
這段代碼主要實(shí)現(xiàn)了在強(qiáng)化學(xué)習(xí)(RL)訓(xùn)練過(guò)程中,定期評(píng)估智能體(agent)的性能,并在某些條件下保存最佳模型的檢查點(diǎn)。我們可以分段解釋該代碼:
if i_episode % config.eval_interval == 0 and config.eval is True:
這部分代碼用于判斷是否應(yīng)該執(zhí)行智能體的評(píng)估。條件為:
i_episode % config.eval_interval == 0
:表示每隔
config.eval_interval
個(gè)訓(xùn)練回合(
i_episode
是當(dāng)前回合數(shù))進(jìn)行一次評(píng)估。
config.eval is True
:確保
eval
設(shè)置為
True
,也就是說(shuō),評(píng)估功能開(kāi)啟。
如果滿足這兩個(gè)條件,代碼將開(kāi)始執(zhí)行評(píng)估操作。
eval_reward_list = []
用于存儲(chǔ)每個(gè)評(píng)估回合(episode)的累計(jì)獎(jiǎng)勵(lì),以便之后計(jì)算平均獎(jiǎng)勵(lì)。
for _ in range(config.eval_episodes):
評(píng)估階段將運(yùn)行多個(gè)回合(由
config.eval_episodes
指定的回合數(shù)),以獲得智能體的表現(xiàn)。
state = env.reset()
episode_reward = []
done = False
env.reset()
:重置環(huán)境,獲得初始狀態(tài)
state
。
episode_reward
:初始化一個(gè)列表,用于存儲(chǔ)當(dāng)前回合中智能體獲得的所有獎(jiǎng)勵(lì)。
done = False
:用
done
來(lái)跟蹤當(dāng)前回合是否結(jié)束。
while not done:
action = agent.select_action(state, evaluate=True)
next_state, reward, done, info = env.step(action)
state = next_state
episode_reward.append(reward)
動(dòng)作選擇
:
agent.select_action(state, evaluate=True)
在評(píng)估模式下根據(jù)當(dāng)前狀態(tài)
state
選擇動(dòng)作。
evaluate=True
表示該選擇是在評(píng)估模式下,通常意味著探索行為被關(guān)閉(即不進(jìn)行隨機(jī)探索,而是選擇最優(yōu)動(dòng)作)。
環(huán)境反饋
:
next_state, reward, done, info = env.step(action)
通過(guò)執(zhí)行動(dòng)作
action
,環(huán)境返回下一個(gè)狀態(tài)
next_state
,當(dāng)前獎(jiǎng)勵(lì)
reward
,回合是否結(jié)束的標(biāo)志
done
,以及附加信息
info
。
狀態(tài)更新
:當(dāng)前狀態(tài)被更新為
next_state
,并將獲得的獎(jiǎng)勵(lì)
reward
存儲(chǔ)在
episode_reward
列表中。
循環(huán)持續(xù),直到回合結(jié)束(即
done == True
)。
eval_reward_list.append(sum(episode_reward))
當(dāng)前回合結(jié)束后,累計(jì)獎(jiǎng)勵(lì)(
sum(episode_reward)
)被添加到
eval_reward_list
,用于后續(xù)計(jì)算平均獎(jiǎng)勵(lì)。
avg_reward = np.average(eval_reward_list)
在所有評(píng)估回合結(jié)束后,計(jì)算
eval_reward_list
的平均值
avg_reward
。這是當(dāng)前評(píng)估階段智能體的表現(xiàn)指標(biāo)。
if config.save_checkpoint:
if avg_reward >= best_reward:
best_reward = avg_reward
agent.save_checkpoint(checkpoint_path, 'best')
config.save_checkpoint
為
True
,則表示需要檢查是否保存模型。
avg_reward
是否超過(guò)了之前的最佳獎(jiǎng)勵(lì)
best_reward
,如果是,則更新
best_reward
,并保存當(dāng)前模型的檢查點(diǎn)。
agent.save_checkpoint(checkpoint_path, 'best')
這行代碼會(huì)將智能體的狀態(tài)保存到指定的路徑
checkpoint_path
,并標(biāo)記為
"best"
,表示這是性能最佳的模型。
咳咳,可以發(fā)現(xiàn)程序只記錄了 0~1000 的數(shù)據(jù),從 1001 開(kāi)始的每一個(gè)數(shù)據(jù)都顯示報(bào)錯(cuò)所以被舍棄掉了。
后面重新下載了github代碼包,發(fā)生了同樣的報(bào)錯(cuò)信息
報(bào)錯(cuò)信息是:你在 X+1 輪次中嘗試記載 X 輪次中的信息,所以這個(gè)數(shù)據(jù)被舍棄掉了
大概是主程序哪里有問(wèn)題吧,我自己也沒(méi)調(diào) bug
不過(guò)這個(gè)項(xiàng)目結(jié)題了,主要負(fù)責(zé)這個(gè)項(xiàng)目的博士師兄也畢業(yè)了,也不好說(shuō)些什么(雖然我有他微信),至少論文里面的模塊挺有用的。ㄊ謩(dòng)滑稽)
機(jī)器學(xué)習(xí):神經(jīng)網(wǎng)絡(luò)構(gòu)建(下)
閱讀華為Mate品牌盛典:HarmonyOS NEXT加持下游戲性能得到充分釋放
閱讀實(shí)現(xiàn)對(duì)象集合與DataTable的相互轉(zhuǎn)換
閱讀鴻蒙NEXT元服務(wù):論如何免費(fèi)快速上架作品
閱讀算法與數(shù)據(jù)結(jié)構(gòu) 1 - 模擬
閱讀基于鴻蒙NEXT的血型遺傳計(jì)算器開(kāi)發(fā)案例
閱讀5. Spring Cloud OpenFeign 聲明式 WebService 客戶端的超詳細(xì)使用
閱讀Java代理模式:靜態(tài)代理和動(dòng)態(tài)代理的對(duì)比分析
閱讀Win11筆記本“自動(dòng)管理應(yīng)用的顏色”顯示規(guī)則
閱讀本站所有軟件,都由網(wǎng)友上傳,如有侵犯你的版權(quán),請(qǐng)發(fā)郵件[email protected]
湘ICP備2022002427號(hào)-10 湘公網(wǎng)安備:43070202000427號(hào)© 2013~2025 haote.com 好特網(wǎng)