|
TECH is the new sexy 編者按:本文節(jié)選自《深度學(xué)習(xí)理論與實(shí)戰(zhàn):提高篇 》一書(shū),原文鏈接http://fancyerii./2019/03/14/dl-book/ 。作者李理,環(huán)信人工智能研發(fā)中心vp,有十多年自然語(yǔ)言處理和人工智能研發(fā)經(jīng)驗(yàn),主持研發(fā)過(guò)多款智能硬件的問(wèn)答和對(duì)話系統(tǒng),負(fù)責(zé)環(huán)信中文語(yǔ)義分析開(kāi)放平臺(tái)和環(huán)信智能機(jī)器人的設(shè)計(jì)與研發(fā)。 以下為正文。 本文介紹時(shí)間差分(Temporal Difference)方法。會(huì)分別介紹On-Policy的SARSA算法和Off-Policy的Q-Learning算法。因?yàn)镺ff-Policy可以高效的利用以前的Episode數(shù)據(jù),所以后者在深度強(qiáng)化學(xué)習(xí)中被得到廣泛使用。我們會(huì)通過(guò)一個(gè)Windy GridWorld的簡(jiǎn)單游戲介紹這兩種算法的實(shí)現(xiàn)。 時(shí)間差分是一種非常重要的強(qiáng)化學(xué)習(xí)方法,它結(jié)合了動(dòng)態(tài)規(guī)劃和蒙特卡羅方法的優(yōu)點(diǎn)。 時(shí)間差分預(yù)測(cè)(TD Prediction) 我們首先回顧一下MC的增量更新公式: 這個(gè)和之前Q(s,a)的更新公式稍微有一些區(qū)別,這里的α是個(gè)常量,而在之前的更新公式是一個(gè)不斷變化的量。但其基本思路是一致的,我們“期望”的V(St)是GtGt,因此[Gt?V(St)]可以認(rèn)為是現(xiàn)在“估計(jì)”和實(shí)際值的“誤差”,再乘以一個(gè)較小的數(shù)字α。這有點(diǎn)像梯度下降,如果誤差為零,那么就沒(méi)有變化,如果誤差越大,則V的更新也越多。這個(gè)算法就叫constant-α MC。 前面也討論過(guò)了,蒙特卡羅方法的缺點(diǎn)是Gt只有在Episode結(jié)束后才能計(jì)算出來(lái)。接下來(lái)我們介紹的TD(0)方法能夠解決這個(gè)問(wèn)題,首先我們來(lái)看它的更新公式: 從公式來(lái)看,時(shí)刻t的狀態(tài)St的更新不需要等到Episode結(jié)束,只需要等到下一個(gè)時(shí)刻t+1。蒙特卡羅方法的更新目標(biāo)(Update Target)是Gt,而TD(0)的目標(biāo)是Rt+1+γV(St+1)。因?yàn)門(mén)D(0)更新一個(gè)狀態(tài)的價(jià)值函數(shù)時(shí)需要依賴另外一個(gè)狀態(tài)的價(jià)值,所以它是bootstrapping的方法。 大體來(lái)說(shuō),蒙特卡羅方法使用第一個(gè)等式的估計(jì)(采樣平均)作為更新目標(biāo),而動(dòng)態(tài)規(guī)劃使用第三個(gè)公式作為更新目標(biāo)。而TD(0)使用第三個(gè)公式的估計(jì)(采樣),同時(shí)還用當(dāng)前的V(St+1)來(lái)近似vπSt+1。因此TD(0)可以認(rèn)為是結(jié)合了蒙特卡羅采樣和bootstrapping,bootstrapping是用估計(jì)來(lái)更新估計(jì)。根據(jù)前面的分析,TD(0)的誤差為: 我們來(lái)看一下TD(0)和MC的關(guān)系,MC方法在Episode結(jié)束之前是不會(huì)改變V(s)的,但是 TD(0)會(huì)在t+1時(shí)刻更新St。為了便于分析,我們暫時(shí)假設(shè)直到Episode結(jié)束才統(tǒng)一更新。 Driving Home例子為了比較MC和TD方法,我來(lái)看一個(gè)例子。我們需要估計(jì)開(kāi)車回家要花的時(shí)間,當(dāng)離開(kāi)辦公室的時(shí)候,我們會(huì)注意現(xiàn)在的時(shí)間,今天是星期幾,今天的天氣怎么樣,綜合考慮所有可能影響交通的因素。比如今天是星期五,現(xiàn)在是下午6點(diǎn),根據(jù)以往的經(jīng)驗(yàn),我們估計(jì)可能需要花30分鐘。當(dāng)走到車前時(shí)已經(jīng)是6:05了,我們發(fā)現(xiàn)開(kāi)始下雨了,因?yàn)橄掠晏旖煌〞?huì)變壞,我們重新估計(jì)的結(jié)果是我們還需要35分鐘才能到家,因此此時(shí)估計(jì)總的到家時(shí)間是40分鐘。15分鐘后我們下了高速,這比預(yù)計(jì)的時(shí)間要短,因此我們重新估計(jì)總的時(shí)間為35分鐘。不過(guò)很不幸,前面有個(gè)大貨車,路又很窄超不了車,6:40才到底小區(qū)的路口。這已經(jīng)花了40分鐘了,根據(jù)經(jīng)驗(yàn),3分鐘后就可以到家了,因此我們重新估計(jì)總的時(shí)間是43分鐘后。而3分鐘后,果然我們?nèi)缙诘郊摇?/p> 假設(shè)狀態(tài)就是每一段路的起始和結(jié)束,然后我們需要估計(jì)的是從當(dāng)前狀態(tài)(點(diǎn))到家的時(shí)間(如果是為了找到回家最快的策略,我們可以把Reward定義為花費(fèi)時(shí)間的負(fù)值,不過(guò)這里我們只考慮預(yù)測(cè),因此我們用正值,這樣看起來(lái)簡(jiǎn)單)。我們把每個(gè)狀態(tài)(一段路的開(kāi)始點(diǎn))開(kāi)始的時(shí)間點(diǎn),這段路估計(jì)要花的時(shí)間,總的估計(jì)的(從辦公室)到家的時(shí)間用表格畫(huà)出來(lái): 我們可以把每個(gè)狀態(tài)估計(jì)總的到家時(shí)間(表格的最后一列)畫(huà)出來(lái),如下圖所示。我們先看左圖,這是MC方法的情況,虛線表示每個(gè)狀態(tài)的估計(jì)值和真實(shí)值的差δ。在Episode結(jié)束之前,我們不知道真實(shí)值是多少,只有到家之后,我們知道總共花費(fèi)了43分鐘,那么我們知道了誤差δ,從而可以更新每個(gè)狀態(tài)的V(s)。而右邊是TD的情況,我們不需要等到家,而只需要在下一個(gè)狀態(tài)結(jié)束,我們就能更新前一個(gè)狀態(tài)。比如最初我們估計(jì)要花30分鐘,但到了第二個(gè)狀態(tài)來(lái)到車前時(shí),我們發(fā)現(xiàn)情況有變,我們重新估計(jì)可能要花40分鐘,其實(shí)這個(gè)時(shí)候TD就可以用40-30作為δ來(lái)更新V(S1)了。 圖:Driving Home示例 TD和MC的比較TD相對(duì)于MC最大的優(yōu)點(diǎn)當(dāng)然就是它是online的算法,不用等到Episode結(jié)束就可以更新,因此也就可以用于連續(xù)的任務(wù)。此外,TD通常收斂的更快,當(dāng)然這只是經(jīng)驗(yàn),并沒(méi)有理論的證明。Gt=Rt+1+γRt+2+…+γT?1RT是vπ(St)的無(wú)偏估計(jì),而”真實(shí)“的TD更新目標(biāo)Rt+1+γvπ(St+1)也是無(wú)偏估計(jì),但是因?yàn)槲覀儾⒉恢纕π(St+1)而是用V(St+1)來(lái)近似的,也就是我們實(shí)際用的TD更新目標(biāo)是Rt+1+γV(St+1)。這個(gè)目標(biāo)是有偏的估計(jì)。因?yàn)镸C依賴很多隨機(jī)的Action、隨機(jī)的狀態(tài)跳轉(zhuǎn)和隨機(jī)的reward,所以它的估計(jì)方差較大,而TD值有一次隨機(jī)的Action、隨機(jī)的狀態(tài)跳轉(zhuǎn)和隨機(jī)的reward,因此方差較小。此外MC因?yàn)橛衎ootstrapping,因此它的收斂也依賴于初始值。 接下來(lái)我們分析一下MC和TD優(yōu)化的“目標(biāo)”分別是什么?首先我們來(lái)看一個(gè)例子。假設(shè)這個(gè)MRP(我們不考慮Action)有兩個(gè)狀態(tài)A和B,我們有如下8個(gè)Episode: (A,0,B,0)表示初始化狀態(tài)A,進(jìn)入B,reward是0,然后進(jìn)入B,reward是0,結(jié)束。 圖:AB問(wèn)題 對(duì)于上面的數(shù)據(jù),我們計(jì)算VV(B)=(6?1+2?0)/8=0.75,那么V(A)呢?我們可能有兩種算法,第一種算法如上圖所示:A不是終止?fàn)顟B(tài),A跳轉(zhuǎn)到B的概率是100%,因此V(A)和V(B相同)。而第二種算法是V(A)=0?1/1=0,也就是出現(xiàn)A的Episode一次,最終的Reward是0,因此計(jì)算平均值就是0。第二種方法是MC,它不考慮A和B的關(guān)系,只是看最終的回報(bào)。而第一種是TD,它會(huì)用B的值來(lái)計(jì)算A的值。MC的目標(biāo)函數(shù)是最小均方誤差: 因此對(duì)于上面的例子只有一個(gè)Episode,G=0,所有V(A)=0時(shí)損失是最小0。似乎看起來(lái)最小均方誤差是不錯(cuò)的目標(biāo)函數(shù),那還有沒(méi)有更好的呢?它的問(wèn)題是沒(méi)有利用環(huán)境的馬爾科夫?qū)傩裕鳷D的目標(biāo)函數(shù)就利用了這個(gè)特性,它是從所有的MDP里選擇似然(likelihood)最大的那個(gè)MDP,然后根據(jù)這個(gè)MDP來(lái)計(jì)算最優(yōu)的V(s),也就是它先根據(jù)數(shù)據(jù)估計(jì)出MDP的參數(shù),對(duì)于上面的AB問(wèn)題,它的MDP動(dòng)力系統(tǒng)是:
然后根據(jù)這個(gè)MDP計(jì)算出V(A)=V(B)=0.75。從上面的分析來(lái)說(shuō),如果環(huán)境是MDP的,那么TD會(huì)好一些。 SARSA有了TD(0)來(lái)進(jìn)行策略評(píng)估(預(yù)測(cè)問(wèn)題),接下來(lái)我們就可以用它來(lái)找最優(yōu)策略(控制問(wèn)題)。我們首先介紹On-Policy的算法SARSA。之前我們的TD(0)的更新公式是關(guān)于V(s)的,現(xiàn)在我們首先把它改成Q(s,a)的:
這個(gè)公式更新是需要下一個(gè)t+1時(shí)刻d的St+1和At+1,再加上t時(shí)刻的St和At,以及Rt+1。這五個(gè)字母拼起來(lái)就是SARSA,因此這個(gè)算法就叫SARSA算法。偽代碼如下:
圖:SARSA算法偽代碼 Windy Gridworld環(huán)境介紹在介紹實(shí)現(xiàn)SARSA的代碼之前,我們先來(lái)構(gòu)建一個(gè)Windy Gridworld的環(huán)境,并且會(huì)說(shuō)明為什么這個(gè)問(wèn)題很難用MC來(lái)解決而很容易用TD來(lái)解決。 如下圖所示,和普通的Gridworld不同,每一列的點(diǎn)都有風(fēng),比如第7列和第8列的風(fēng)速都是2,如果我們從第7列采取向右的Action,則它會(huì)向右走一格并且被風(fēng)吹得往上走兩格。圖中的路徑是最優(yōu)的路徑。 這個(gè)環(huán)境用MC方法效果就不好,因?yàn)楹芏郋pisode很長(zhǎng)甚至如果某個(gè)策略不好的話,可能永遠(yuǎn)到底不了終點(diǎn)。而TD方法就能解決這個(gè)問(wèn)題,因?yàn)樗挥玫鹊浇Y(jié)束就可以根據(jù)Reward更新了。
圖:Windy Gridworld 我們先看WindyGridworldEnv,完整代碼在這里。 對(duì)于二維網(wǎng)格這樣的環(huán)境,我們的類可以繼承discrete.DiscreteEnv,然后實(shí)現(xiàn)render方法就行。那step呢?我們需要實(shí)現(xiàn)環(huán)境的動(dòng)力學(xué)P(s′|s,a)。對(duì)于Windy GridWorld來(lái)說(shuō)狀態(tài)是(7,10)的數(shù)組,總共有70個(gè)狀態(tài)。每種狀態(tài)有4個(gè)Action,表示我們讓Agent向上下左右4個(gè)方向移動(dòng)。這個(gè)環(huán)境是確定的,因此對(duì)于每一個(gè)(s,a)的組合,只有一個(gè)s’的概率是1,其余的是0。除了P(s′|s,a),我們還需要知道初始狀態(tài)的概率分布,我們這里很簡(jiǎn)單,它的初始狀態(tài)也是固定的在(3,0),因此在這點(diǎn)的概率是1,而其余點(diǎn)的概率都是0。 這些信息需要通過(guò)調(diào)用父類的構(gòu)造函數(shù)告訴OpenAI Gym,如下所示:
這里nS=70,告訴DiscreteEnv這個(gè)環(huán)境有70個(gè)狀態(tài)。nA=4,表示每個(gè)狀態(tài)都有4種可能的Action——上下左右。 P是一個(gè)dict,key是(0-69),表示每個(gè)狀態(tài)的轉(zhuǎn)移概率。注意:我們用二維數(shù)組表示狀態(tài),但是DiscreteEnv要求狀態(tài)是一維的。我們需要在二維和一維之間進(jìn)行轉(zhuǎn)換,這里會(huì)用到numpy.ravel_multi_index函數(shù)。我們通過(guò)幾個(gè)例子來(lái)學(xué)習(xí)這個(gè)函數(shù):
我們先看第二個(gè)參數(shù)(7,6),它的意思是二維數(shù)組的大小是(7,6)。而輸入是3組二維坐標(biāo)(3,4)、(6,5)和(6,1),默認(rèn)把二維變成一維是類似與C語(yǔ)言的二維數(shù)組——首先是第一行的6個(gè)數(shù),然后是第二行。因此(3,4)對(duì)應(yīng)的一維下標(biāo)是3*6+4=22。 除了C語(yǔ)言的行優(yōu)先(默認(rèn)),還可以類似Fortran語(yǔ)言的列優(yōu)先:
(3,4)表示第3行第4列,因?yàn)槭橇袃?yōu)先,所以4*7+3=31。 兩者的區(qū)別如下:
我們以第二行第七列為例,它的二維坐標(biāo)是(1, 6),對(duì)應(yīng)的一維坐標(biāo)是1*10+6=16。P[16]的內(nèi)容為:
這又是一個(gè)dict,key是4個(gè)Action,0表示UP、1表示RIGHT、2表示DOWN、3表示LEFT。因此上面的例子表示P[16]往上走的概率分布是[(1.0, 6, -1.0, False)],這是一個(gè)數(shù)組。通常P(s′|s,a)是一個(gè)概率,s’可以取很多可能值,但是我們這里只有一個(gè)s’的概率不是零(是1),因此我們的這個(gè)數(shù)組只有一個(gè)元素。1.0表示概率P(6|16,UP)=1,狀態(tài)6轉(zhuǎn)換成二維左邊是(0,6),確實(shí)是在(1,6)的上方。-1.0表示Reward,F(xiàn)alse表示這個(gè)狀態(tài)不是結(jié)束狀態(tài)。 另外如果走一步可能越界,那么就保持原地不動(dòng),比如我們來(lái)看P[0]:
(0, 0)是最左上的點(diǎn),它往上(0)和往左(3)都越界(碰墻),因此還是呆在原地不到。 我們的目標(biāo)點(diǎn)是第四行第八列,坐標(biāo)是(3,7),因此P[48]為:
它表示第五行第九列(狀態(tài)48)往左走一步就進(jìn)入目標(biāo)點(diǎn),因?yàn)樗送笞?,還會(huì)被風(fēng)上吹一步。這是如上圖所示的最右一步。注意:風(fēng)力是離開(kāi)某點(diǎn)起作用的。比如現(xiàn)在在(4, 8),它的風(fēng)力是往上的1;它往左一步就進(jìn)入(4, 7),然后被風(fēng)吹上一格變成(3, 7)。這里的風(fēng)力是(4, 8)點(diǎn)也就是起點(diǎn)的風(fēng)力。有的讀者可能會(huì)以為先走的(4,7)點(diǎn),然后用這點(diǎn)風(fēng)力來(lái)吹,但是這點(diǎn)的風(fēng)力是往上的2,那就會(huì)變成(2, 7),這樣理解是不對(duì)的。 isd(Initial State Distribution)表示初始狀態(tài)的分布,這是一個(gè)長(zhǎng)度為70的數(shù)組,表示初始處于這個(gè)狀態(tài)的概率,我們這里返回的isd只有在下標(biāo)30(對(duì)應(yīng)的二維下標(biāo)是(3,0))是1,其余都是零,也就是初始狀態(tài)總是在第四行第一列。 理解了這些,代碼就很好理解了,render函數(shù)就是把它用圖形(ascii art)的方式展現(xiàn)出來(lái),這里就不贅述了,完整代碼如下:
我們來(lái)嘗試一下這個(gè)環(huán)境,代碼在Windy Gridworld.ipynb。
輸出為:
它對(duì)應(yīng)的就是上圖所示的前4步。 SARSA代碼代碼在SARSA.ipynb,非常簡(jiǎn)單,基本和偽代碼一樣:
對(duì)于windy gridworld任務(wù),我們的超參數(shù)ε=0.1,TD(0)收斂后的策略平均需要17步,比最優(yōu)的15步多2步,原因是它有0.1的概率會(huì)隨機(jī)采取行為。下圖我們繪制是Episode長(zhǎng)度的變化,可以看出,剛開(kāi)始一個(gè)episode很長(zhǎng)(斜率很低),然后隨著迭代快速收斂到一個(gè)最優(yōu)這(斜率不再變化)。
圖:SARSA學(xué)習(xí)過(guò)程 Q-學(xué)習(xí)(Q-Learning)接下來(lái)我們討論一種Off-Policy的TD學(xué)習(xí)算法Q-Learning,這是非常流行的一種算法,后面我們介紹深度學(xué)習(xí)和強(qiáng)化學(xué)習(xí)的結(jié)合時(shí)就會(huì)介紹Deep Q-Learning。我們知道Off-Policy有兩個(gè)策略——目標(biāo)策略和行為策略。對(duì)于Q-Learning來(lái)說(shuō)也是有兩個(gè)策略的,但是和之前的Off-Policy不同,Q-Learning的兩個(gè)策略都是依賴與同一個(gè)Q函數(shù),因此叫做Q-Learning。首先我們看一下怎么把基于重要性采樣的Off-Policy MC算法推廣到基于重要性采樣的Off-Policy TD算法,然后再分析Q-Learning是怎么來(lái)的。 回顧一下前面的內(nèi)容,Off-Policy的MC算法的核心點(diǎn)是用行為策略采樣Episode,但是更新V(s)或者Q(s,a)時(shí)回報(bào)要乘以重要性比例:
V(s)的更新公式是:
如果我們用TD來(lái)代替MC,那么更新公式是:
Q-Learning有兩個(gè)策略,基于Q(s,a)的貪心策略,這是目標(biāo)策略;基于Q(s,a)的ε-貪婪策略,這是行為策略。此外Q-Learning不使用重要性采樣,因此ρρ是1。因此Q-Learning的更新公式是:
在狀態(tài)StSt是的Action使用ε-貪婪的策略,采取行為AtAt之后進(jìn)入狀態(tài)St+1St+1,這個(gè)時(shí)候的A’使用目標(biāo)策略:
把這個(gè)式子代入上式得到Q-Learning的更新目標(biāo):
和SARSA不同,前者兩次行為At和At+1都由ε-貪婪的策略生成,而Q-Learning中,我們用行為策略(ε-貪婪的策略)生成了At,用目標(biāo)策略來(lái)“模擬生成”At+1。 假設(shè)初始化狀態(tài)是S0,SARSA是如下更新的:根據(jù)行為策略生成A0,執(zhí)行此Action,進(jìn)入狀態(tài)S1,然后再更加相同的行為策略生成A1,注意此時(shí)還沒(méi)有執(zhí)行A1,此時(shí)就可以更加SARSA公式更新Q從而更新行為策略了。而如果是Q-Learning,更加行為策略生成A0,執(zhí)行詞Action,進(jìn)入S1,此時(shí)就可以更新Q從更新行為策略了。接著用新的行為策略選擇A1并進(jìn)入S2。 從上面的比較可以看出,對(duì)于SARSA,A0和A1都是有初始化的Q對(duì)于的ε-貪婪策略生成的行為;而對(duì)于Q-Learning,A0是用初始化的策略,而A1已經(jīng)是一個(gè)新的行為策略了。了解了他們的區(qū)別之后,Q-Learning的偽代碼就很簡(jiǎn)單了:
圖:Q-Learning算法 Q-Learning代碼完整代碼在Q-Learning.ipynb。
如下圖所示,運(yùn)行之后我們發(fā)現(xiàn)Q-Learning最后學(xué)到了最優(yōu)的策略——最優(yōu)的步數(shù)15。和On-line的SARSA對(duì)比,Off-Policy策略的Q-Learning能夠?qū)W到最優(yōu)的策略。
圖:Q-Learning學(xué)習(xí)過(guò)程 |
|
|
來(lái)自: 昵稱535749 > 《IT業(yè)與人工智能》