ローリングコンバットピッチなう!

AIとか仮想化とかペーパークラフトとか

強化学習の勉強開始

[technology]Q-LearningによるCartpole(from Open AI Gym)の学習状況を可視化する

Jubatusでアノマリ検知を試すとか言っていたのですが、チュートリアルは普通に動きそうだし、自前で試すのにちょうど良いデータがなかなか見つからないので、Jubatusは放置中です。ちゃんとやりたい事とデータが揃えば有用なことは判ったので、しばらくは引き出しにしまっておきます。(笑)

で、専門でやっている人たちは今更なんでしょうが、時流に乗って強化学習の勉強を始めました。
最初はDeep Q-Learningやるぞ〜!!みたいな感じで居たのですが、色々ググッてみてもイマイチ良く判らず....色々な解説をあたっていくうちに、Deepの前に単なるQ-Learningというものがあるのことを知りました。

ネットで様々な記事を見てなんとなく、雰囲気は判ってきたのですが、とりあえずサンプル付で解りやすい記事が下記にありました。

deepage.net

OpenAIというキーワードも以前から気になってましたが、この記事で謎が解けた感じ。
とりあえずほぼ上記の記事の解説の通りにpythonでコード書いたのですが、学習の様子を可視化したいと思って、サンプルコードに手を加えました。

ソース全体は下記の通りです。デバッグ用に埋め込んだprint文とかもコメント化したまま残しています。

# -*- Coding: utf-8 -*-

import gym
import numpy as np
import time
import matplotlib as mpl
mpl.use('TkAgg')
import matplotlib.pyplot as plt

env = gym.make('CartPole-v0')

goal_average_steps = 195
max_number_of_steps = 200
num_consecutive_iterations = 100
num_episodes = 2000
#num_episodes = 100
last_time_steps = np.zeros(num_consecutive_iterations)
t_hist = np.zeros(num_episodes)

q_table = np.random.uniform(low=-1,high=1,size=(4 ** 4,env.action_space.n))

def bins(clip_min,clip_max,num):
    return np.linspace(clip_min,clip_max,num + 1)[1:-1]

def digitize_state(observation):
    # Convert each state value to descrete expression
    cart_pos,cart_v,pole_angle,pole_v = observation
    digitized = [np.digitize(cart_pos,bins = bins(-2.4,2.4,4)),
                 np.digitize(cart_v,bins=bins(-3.0,3.0,4)),
                 np.digitize(pole_angle,bins=bins(-0.5,0.5,4)),
                 np.digitize(pole_v,bins=bins(-2.0,2.0,4))]
    # Convert to 0-255
    return sum([x * (4 ** i) for i, x in enumerate(digitized)])

def get_action(state,action,observation,reward,episode):
    # Obtain next action from Q-table
    next_state = digitize_state(observation)

#    epsilon = 0.2
    epsilon = 0.5 * (0.99 ** episode)
    if epsilon <= np.random.uniform(0,1):
        next_action = np.argmax(q_table[next_state])
    else:
        next_action = np.random.choice([0,1])

    # Update Q-Table
    alpha = 0.2
    gamma = 0.99
    q_table[state,action] = (1 - alpha) * q_table[state,action] + alpha * (reward + gamma * q_table[next_state,next_action])

    return next_action,next_state



for episode in range(num_episodes):
    # Init Environment
    observation = env.reset()

    state = digitize_state(observation)
    action = np.argmax(q_table[state])

    episode_reward = 0
    for t in range(max_number_of_steps):
        # Draw CartPole
        env.render()

 #       # Select action randomly
 #       action = np.random.choice([0,1])

        # Execute action and get feedback
        observation, reward, done, info = env.step(action)

        # Add penalty if fails
        if done:
            reward = -200
        else:
            episode_reward += reward

        # Select next action
        action,state = get_action(state,action,observation,reward,episode)
#        episode_reward += reward

        print('episode %d:time %d:action %d:reward %d:episode_reward %d' % (episode,t,action,reward,episode_reward))

        if done:
            last_time_steps = np.hstack((last_time_steps[1:],[t]))
            print('%d Episode finished after %f times steps / mean %f' % (episode,t + 1,last_time_steps.mean()))
#            last_time_steps = np.hstack((last_time_steps[1:],[episode_reward]))
#            last_time_steps = np.hstack((last_time_steps[1:],[t]))

            t_hist[episode] = t
            break
#        time.sleep(0.1)
    
    if (last_time_steps.mean() >= goal_average_steps):
        print('Episode %d train agent successfully!' % episode)
        t_hist[episode] = t
#        break

plt.plot(np.linspace(0,episode - 1,episode),t_hist[0:episode])
plt.show()

exit()

やっていることは、各episode毎の何ステップポールを倒さずに続いたかのタイムステップ数を記録して、matplotlibでグラフに描き出します。
結果は以下の様な感じで、1100超えたあたりから一気に毎episode 200タイムステップ連続成功する様になります。
逆に1000episodeくらいのところでも結構100タイムステップ持ちこたえずに失敗しているところがあるので、なんとなく人間が自転車の練習をしていて、ある日突然乗れる様になってそこからは転ばなくなる感じと似ていますね。

オリジナルのコードからの修正点はざっくりと下記の通りです。

  • matplotlib import

    単純に

    import matplotlib.pyplot as plt

    とやると自分の環境(Ubuntu 16.04 LTS + Anaconda3上のpython3.6環境)ではSegmentation Faultが起こるので、

    import matplotlib as mpl
    mpl.use('TkAgg')
    import matplotlib.pyplot as plt

    と、おまじないをしています。

  • 成功判定の変更

    オリジナルコードでは各episode内でdoneになった時に直近100episode分のepisode_rewardの平均値が目標(195steps)を超えた事を判定し、学習終了としています。しかし失敗時にrewardにペナルティ(-200)を設定するコードのためにこれが上手く動いていない様です。
    単純に各episodeの終了時のtime step (t)の直近100episodeの平均値を見る様にしました。
  • 各episodeのtime step数の履歴を記録

    t_histというnumpyのarrayを5000要素分定義し、各episodeの終了時のtをすべて記録しました。
  • 最後にt_histをmatplotlibで描画

    下記のコードで0〜4999までのepisode番号を横軸に、各episodeのtime step数を縦軸にグラフを描画します。

    plt.plot(np.linspace(0,episode - 1,episode),t_hist[0:episode])
    plt.show()

なお以下のコードがコメントアウトされて入っていますが、CPUの全力で回すとcartpoleの動きが早すぎて動きが追えないため、time step毎にsleepを入れるコードです。
コメントアウトすると目視でcartpoleの動きが終えるくらいの速度になります。ただし5000episodeをこれでやる切るのは辛いので、あくまでも描いたコードが思った様に動いているかの確認用です。

#        time.sleep(0.1)

上記の記事は非常に判りやすいのですが、やはり1行1行写経しつつ、特にQ-Tableの更新の様子を自分の頭でなぞりながら動かさないと、なかなかこの学習の根幹のアイデアが理解出来ないように思います。Q-Table自体の更新の様子を可視化できないか悩んでいるのですが、256stateの各actionの評価値をただ並べてもあまりわかりやすくないので、ポールの角度だけに着目してactionの評価値の平均値を途中のepisodeで比べるとか...何か良い可視化方法が無いか考え中です。

とりあえず次は以下の記事で勉強に取り組みます。Deep Q-Learningにはいつ辿りつけるかな?
deepage.net
追記: 上記記事にほぼ同じコードで学習状況可視化が書かれていた。まあ、自分で考えてやったことに無駄は無いでしょう....

良い本無いかな?と思ってAmazonの書評とか見ていますが、なかなか理論と実践のバランスの取れた本って無いですよね。
自分はコードで覚える派なので、基本的な考え方とサンプルが示されていれば、あとは試行錯誤でなんとなく感覚的に理解するんですが、ちゃんと理論も勉強した方が良いのかな?