[technology]Q-LearningによるCartpole(from Open AI Gym)の学習状況を可視化する
Jubatusでアノマリ検知を試すとか言っていたのですが、チュートリアルは普通に動きそうだし、自前で試すのにちょうど良いデータがなかなか見つからないので、Jubatusは放置中です。ちゃんとやりたい事とデータが揃えば有用なことは判ったので、しばらくは引き出しにしまっておきます。(笑)
で、専門でやっている人たちは今更なんでしょうが、時流に乗って強化学習の勉強を始めました。
最初はDeep Q-Learningやるぞ〜!!みたいな感じで居たのですが、色々ググッてみてもイマイチ良く判らず....色々な解説をあたっていくうちに、Deepの前に単なるQ-Learningというものがあるのことを知りました。
ネットで様々な記事を見てなんとなく、雰囲気は判ってきたのですが、とりあえずサンプル付で解りやすい記事が下記にありました。
OpenAIというキーワードも以前から気になってましたが、この記事で謎が解けた感じ。
とりあえずほぼ上記の記事の解説の通りにpythonでコード書いたのですが、学習の様子を可視化したいと思って、サンプルコードに手を加えました。
ソース全体は下記の通りです。デバッグ用に埋め込んだprint文とかもコメント化したまま残しています。
# -*- Coding: utf-8 -*- import gym import numpy as np import time import matplotlib as mpl mpl.use('TkAgg') import matplotlib.pyplot as plt env = gym.make('CartPole-v0') goal_average_steps = 195 max_number_of_steps = 200 num_consecutive_iterations = 100 num_episodes = 2000 #num_episodes = 100 last_time_steps = np.zeros(num_consecutive_iterations) t_hist = np.zeros(num_episodes) q_table = np.random.uniform(low=-1,high=1,size=(4 ** 4,env.action_space.n)) def bins(clip_min,clip_max,num): return np.linspace(clip_min,clip_max,num + 1)[1:-1] def digitize_state(observation): # Convert each state value to descrete expression cart_pos,cart_v,pole_angle,pole_v = observation digitized = [np.digitize(cart_pos,bins = bins(-2.4,2.4,4)), np.digitize(cart_v,bins=bins(-3.0,3.0,4)), np.digitize(pole_angle,bins=bins(-0.5,0.5,4)), np.digitize(pole_v,bins=bins(-2.0,2.0,4))] # Convert to 0-255 return sum([x * (4 ** i) for i, x in enumerate(digitized)]) def get_action(state,action,observation,reward,episode): # Obtain next action from Q-table next_state = digitize_state(observation) # epsilon = 0.2 epsilon = 0.5 * (0.99 ** episode) if epsilon <= np.random.uniform(0,1): next_action = np.argmax(q_table[next_state]) else: next_action = np.random.choice([0,1]) # Update Q-Table alpha = 0.2 gamma = 0.99 q_table[state,action] = (1 - alpha) * q_table[state,action] + alpha * (reward + gamma * q_table[next_state,next_action]) return next_action,next_state for episode in range(num_episodes): # Init Environment observation = env.reset() state = digitize_state(observation) action = np.argmax(q_table[state]) episode_reward = 0 for t in range(max_number_of_steps): # Draw CartPole env.render() # # Select action randomly # action = np.random.choice([0,1]) # Execute action and get feedback observation, reward, done, info = env.step(action) # Add penalty if fails if done: reward = -200 else: episode_reward += reward # Select next action action,state = get_action(state,action,observation,reward,episode) # episode_reward += reward print('episode %d:time %d:action %d:reward %d:episode_reward %d' % (episode,t,action,reward,episode_reward)) if done: last_time_steps = np.hstack((last_time_steps[1:],[t])) print('%d Episode finished after %f times steps / mean %f' % (episode,t + 1,last_time_steps.mean())) # last_time_steps = np.hstack((last_time_steps[1:],[episode_reward])) # last_time_steps = np.hstack((last_time_steps[1:],[t])) t_hist[episode] = t break # time.sleep(0.1) if (last_time_steps.mean() >= goal_average_steps): print('Episode %d train agent successfully!' % episode) t_hist[episode] = t # break plt.plot(np.linspace(0,episode - 1,episode),t_hist[0:episode]) plt.show() exit()
やっていることは、各episode毎の何ステップポールを倒さずに続いたかのタイムステップ数を記録して、matplotlibでグラフに描き出します。
結果は以下の様な感じで、1100超えたあたりから一気に毎episode 200タイムステップ連続成功する様になります。
逆に1000episodeくらいのところでも結構100タイムステップ持ちこたえずに失敗しているところがあるので、なんとなく人間が自転車の練習をしていて、ある日突然乗れる様になってそこからは転ばなくなる感じと似ていますね。
オリジナルのコードからの修正点はざっくりと下記の通りです。
- matplotlib import
単純にimport matplotlib.pyplot as plt
とやると自分の環境(Ubuntu 16.04 LTS + Anaconda3上のpython3.6環境)ではSegmentation Faultが起こるので、
import matplotlib as mpl mpl.use('TkAgg') import matplotlib.pyplot as plt
と、おまじないをしています。
- 成功判定の変更
オリジナルコードでは各episode内でdoneになった時に直近100episode分のepisode_rewardの平均値が目標(195steps)を超えた事を判定し、学習終了としています。しかし失敗時にrewardにペナルティ(-200)を設定するコードのためにこれが上手く動いていない様です。
単純に各episodeの終了時のtime step (t)の直近100episodeの平均値を見る様にしました。
- 各episodeのtime step数の履歴を記録
t_histというnumpyのarrayを5000要素分定義し、各episodeの終了時のtをすべて記録しました。
- 最後にt_histをmatplotlibで描画
下記のコードで0〜4999までのepisode番号を横軸に、各episodeのtime step数を縦軸にグラフを描画します。plt.plot(np.linspace(0,episode - 1,episode),t_hist[0:episode]) plt.show()
なお以下のコードがコメントアウトされて入っていますが、CPUの全力で回すとcartpoleの動きが早すぎて動きが追えないため、time step毎にsleepを入れるコードです。
コメントアウトすると目視でcartpoleの動きが終えるくらいの速度になります。ただし5000episodeをこれでやる切るのは辛いので、あくまでも描いたコードが思った様に動いているかの確認用です。
# time.sleep(0.1)
上記の記事は非常に判りやすいのですが、やはり1行1行写経しつつ、特にQ-Tableの更新の様子を自分の頭でなぞりながら動かさないと、なかなかこの学習の根幹のアイデアが理解出来ないように思います。Q-Table自体の更新の様子を可視化できないか悩んでいるのですが、256stateの各actionの評価値をただ並べてもあまりわかりやすくないので、ポールの角度だけに着目してactionの評価値の平均値を途中のepisodeで比べるとか...何か良い可視化方法が無いか考え中です。
とりあえず次は以下の記事で勉強に取り組みます。Deep Q-Learningにはいつ辿りつけるかな?
deepage.net
追記: 上記記事にほぼ同じコードで学習状況可視化が書かれていた。まあ、自分で考えてやったことに無駄は無いでしょう....
良い本無いかな?と思ってAmazonの書評とか見ていますが、なかなか理論と実践のバランスの取れた本って無いですよね。
自分はコードで覚える派なので、基本的な考え方とサンプルが示されていれば、あとは試行錯誤でなんとなく感覚的に理解するんですが、ちゃんと理論も勉強した方が良いのかな?