MATHGRAM

主に数学とプログラミング、時々趣味について。

t-SNEの結果をplotlyで3D可視化する

前回のplotlyの記事で実践編は暇あったら書きます的なこと言ったのですが,今回はそれに当たる内容です.

内容量はかなり少なく薄いですが,plotlyの使用例程度に思ってくれると有難いです.

t-SNEとは

t-SNEとは,皆さまご存知の通り次元圧縮の手法ですね.高次元データを人間が認知できる次元まで綺麗に落とし込める手法なので使っている人は多いのではないでしょうか.

今回はplotlyの使い方を重視した記事なので,理論の話はしませんが需要があったらまとめますね.一応参考になる資料をここにまとめておきます.

論文:
https://lvdmaaten.github.io/publications/papers/JMLR_2008.pdf

参考サイト:
理論とか置いといてt-sneをアプリケーションとして使う人は読むべき
高次元のデータを可視化するt-SNEの効果的な使い方 - DeepAge

ざっと理論の概要を知りたい人はこちらで.
t-SNE を用いた次元圧縮方法のご紹介 | ALBERT Official Blog

実践してみる

上のt-SNE を用いた次元圧縮方法のご紹介 | ALBERT Official Blogでやってることをpythonに移植する流れでやりたいと思います.

データセットcoil-20を使用し,sklearnに実装されているtsneを用います.

まずは3D

まずは3次元まで落とし込んでみましょう.

import os
import numpy as np
import cv2
from sklearn.manifold import TSNE
from sklearn import preprocessing

import plotly.offline as offline
import plotly.graph_objs as go
offline.init_notebook_mode()

# 画像の前処理.標準化やらL2正規化やら.
def preprocess_image(path, size):
    img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    resized = cv2.resize(img, (size, size), cv2.INTER_LINEAR).astype("float")
    normalized = cv2.normalize(resized, None, 0.0, 1.0, cv2.NORM_MINMAX)
    timg = normalized.reshape(np.prod(normalized.shape))
    return timg/np.linalg.norm(timg) 

ROOT = "./coil-20-proc"
ls = os.listdir(ROOT)

# 名前からラベルを持って来ます.
obj_ls = [name.split("_")[0] for name in ls]

ALL_IMAGE_PATH = [ROOT+"/"+path for path in ls]

# 全画像に対して前処理する
preprocess_images_as_vecs = [preprocess_image(path, 32) for path in ALL_IMAGE_PATH]

# tsneを実行
tsne = TSNE(
    n_components=3, #ここが削減後の次元数です.
    init='random',
    random_state=101,
    method='barnes_hut',
    n_iter=1000,
    verbose=2
).fit_transform(preprocess_images_as_vecs)

たったこれだけで次元削減できてしまいます.sklearnに感謝です.

さて,現在tsneという変数に次元削減後のarrayが入っているのでコイツをplotlyを用いて可視化してみます.

# 3Dの散布図が作れるScatter3dを使います.
trace1 = go.Scatter3d(
    x=tsne[:,0], # それぞれの次元をx, y, zにセットするだけです.
    y=tsne[:,1],
    z=tsne[:,2],
    mode='markers',
    marker=dict(
        sizemode='diameter',
        color = preprocessing.LabelEncoder().fit_transform(obj_ls),
        colorscale = 'Portland',
        line=dict(color='rgb(255, 255, 255)'),
        opacity=0.9,
        size=2 # ごちゃごちゃしないように小さめに設定するのがオススメです.
    )
)

data=[trace1]
layout=dict(height=700, width=600, title='coil-20 tsne exmaple')
fig=dict(data=data, layout=layout)
offline.iplot(fig, filename='tsne_example')

こんな感じで出力されます.グリグリ動かしてみてください.とても綺麗に分離できていることがわかります.尚,円形にplotされていることに関する考察などは先のブログでされているので是非参考にしてみてください.

2Dもやってみる.

ほぼ上と同じようにtsneを実行し,Scatter2dを用いて可視化してみます.

# tsneには2dまで落とし込んだarrayが入っている想定です.

trace = go.Scatter(
    x=tsne[:,0],
    y=tsne[:,1],
    mode='markers',
    marker=dict(
        sizemode='diameter',
        color = preprocessing.LabelEncoder().fit_transform(obj_ls),
        colorscale = 'Portland',
        line=dict(color='rgb(255, 255, 255)'),
        opacity=0.9,
        size=4
    )
)

data=[trace]
layout=dict(height=800, width=800, title='coil-20 tsne exmaple 2D')
fig=dict(data=data, layout=layout)
offline.iplot(fig, filename='tsne2D_example')

f:id:ket-30:20170705035109p:plain:w500:h500

この世の生データに比べたらかなり綺麗に分かれていますが,若干バラついてる部分も見受けられますね.やはり3Dと2Dでは元のデータに対する説明量が違いますので,この程度の差は出てしまいます.

内容としては以上です.

あとがき

plotlyは最近かなり使っているのですが,せっかく使っているのに実践編として記事にできてない状況になっています・・・.これからは,今回くらいの内容の薄さでもいいやぁって開き直って記事の更新頻度を上げていきたいと思います.

次はディリクレ分布を可視化してみようと思います.よろしければそちらも是非.

以上です.

matplotlib使いづらくない?plotlyで可視化しようよ

pythonでグラフを可視化する時,matplotlib使いづらくないですか?覚えにくいし,毎回ググってる気がします.

あとデザインもダサいので全然好きになれません.(デザインに関してはseabornを使えば綺麗ですが,結局matplotlibで書くことになるので覚えづらいことには変わりないです・・・)

ただしmatplotlib画像の表示には強いです.そういう時は僕も使います.

プロmatplotliberの方がこの記事を見てくださって,「お前は何もわかっていない.こんなに素晴らしくグラフをかけるんじゃ」って言われたら素直に土下座します.煽り気味のタイトルで本当に申し訳ないです.

そんなこんなで今回は,割と覚えやすくてデザインもよく,3Dの作図にも強い可視化ライブラリ,plotlyを紹介します.

いきなりですがplotlyではこんな作図ができます.

Note: おそらくスマホではうまく表示されませんのでPCで確認お願いします!

追記: PCでも数字のメモリが [?] になってしまう時があるようです.ちゃんと表示させたいときは更新などをして対処お願いします.申し訳ありません.

ご覧の通り,マウスホバーで詳細を表示できたり,グラフを動かせたり3Dの作図もかなりいい感じにできるのでみなさんも使ってみてください.

アジェンダ

  1. Usage
  2. Basic Charts
    1. Scatter Charts (散布図)
    2. Line Charts (折れ線グラフ)
    3. Bar Charts (棒グラフ)
  3. Statistical and Seaborn-style Charts
    1. Error Bars (誤差付き折れ線グラフ)
    2. Box Plots (箱ひげ図)
    3. Histograms (ヒストグラム)
    4. 2d Histograms (二次元ヒストグラム)
    5. 2d Density Plots
  4. Scientific Charts
    1. Heatmaps
    2. Dendrograms (階層クラスタ)
  5. 実践(暇できたら適当にやっていきます.)

目次に書いたもの以外でもたくさん機能はあるのですが,あんまり使わなそうだなぁと個人的に思ったものは紹介していません.DocumentにGoです.

また最後に実践編として分析例をいくつか載せていく予定です.実践編は随時追加予定なので自分のやりたい分析と近いものがあれば参考にしてみてください.

最後に注意事項.
以下で表示しているグラフは全て画像なので,動かせません.注意してください

もちろんみなさんのローカル環境ではグリグリ動かせるグラフができますのでご安心を.

あと最後の最後におまけなんですが,atomhydrogenを使えばatom内で分析がゴリゴリできます.

f:id:ket-30:20170527212818p:plain

実は僕jupyterもそこまで好きじゃないんで,同志がいたら使ってみてくださいね.

1. Usage

まずはinstall.pipで簡単にいけます.

$ pip install plotly

この記事を見ている多くの方は既にinstallしていると思いますが,pandasやらjupyterやらまぁその辺は入れといてください.

※ ここから下は,plotlyの基本的な書き方の説明です.コードを読む方が早いって方は読まなくて大丈夫です.

さて肝心の作図の方法ですが,だいたいのグラフは以下の流れで作成できます.

  1. オフラインで動くように設定する.
  2. traceを作成する.
  3. layoutを定義する.
  4. iplot, もしくはplotで作図する.

1個ずつ言葉を確認していきましょう.

まず結構重要なオフラインの設定です.

plotlyはアカウントを作って,サーバー上にグラフを保存することができます.一番最初に掲載した,3Dの動くグラフもサーバー上に保存されているグラフをお借りしているものです.

しかし僕の場合はオフラインで事足ります.というかほとんどのユーザーはオフラインで満足なはず.ですので以下のようにimportしましょう.

import plotly.offline as offline
offline.init_notebook_mode()

ここの仕組みについてはあまり考える必要ないと思います.僕も知りません.

お次はtraceです.

traceは作図で一番重要なデータや作図方法の情報が入ったものです.具体的な例を示します.

# 注意: 色々省いているのでこれだけでは動きません! 
import plotly.offline as offline
import plotly.graph_objs as go

trace = go.Scatter(
    x = np.array(setosa[columns[1]]),
    y = np.array(setosa[columns[2]]),
    name = "setosa",
    mode = "markers",
    marker = dict(size=10, color="rgba(255, 0, 255, 0.5)"))

ここではScatter(散布図)を使いtraceを作っています.見てわかるように,この時点でxyなどにデータを与えています.またmarkerの部分で点の大きさや色なども指定しています.plotlyではこのtraceを基本単位として扱います.

次はlayoutについて.

先ほどはデータ点そのものについて色や大きさなどを指定しました.layoutではグラフのタイトルや軸の名前など,ひとつ粒度の大きい部分のデザインを定義していきます.具体的な例は以下です.

layout = go.Layout(
    title='Iris sepal length-width',
    xaxis=dict(title='sepal legth(cm)'),
    yaxis=dict(title='sepal width(cm)'),
    showlegend=True)

最後にiplotです. iplotplotの違いはjupyter内で作図をするかどうかの違いなので,基本的にiplotを使います.

先ほど作った,tracelayoutを辞書で囲んであげてiplotします.ちなみに辞書で情報を整理されたものを,plotlyではfigureと呼んでいるみたいです.

fig = dict(data=data, layout=layout)
offline.iplot(fig, filename="example")

以上が主な作図方法の流れです.

2. Basic Charts

2.1 Scatter Charts (散布図)

Simple Scatter Plots

散布図です.1個目なのでIrisデータ使いましょう.

まずはラベルなしでplotしてみます. 教師なし学習とかを想像しながら見てくださいね.

グラフはこんな感じになります. f:id:ket-30:20170521153357p:plain

サンプルコード

Usageでも紹介した通り,

  1. traceを作成する.
  2. layoutを定義する.
  3. iplot, もしくはplotで作図.
    の流れです.

Style Scatter Plots

次はラベルつきでplotしてみます. 教師データを意識してください.

f:id:ket-30:20170521160938p:plain

versiclorvirginicaがガッツリ混ざっていますね.3種類に分類するとき,この特徴量だけでは足りないことが見て取れます.

サンプルコード

Usageでも紹介した通り,

  1. traceを作成する.
  2. layoutを定義する.
  3. iplot, もしくはplotで作図.
    の流れです.

2.2 Line Charts (折れ線グラフ)

折れ線グラフは時系列データを扱うときに使います.
基本的には散布図と同様にScatterを使い,modelinesを与えるだけです.

f:id:ket-30:20170521162955p:plain

サンプルコード

Usageでも紹介した通り,

  1. traceを作成する.
  2. layoutを定義する.
  3. iplot, もしくはplotで作図.
    の流れです.

2.3 Bar Charts (棒グラフ)

棒グラフはカテゴリカル分布の作図や,
それぞれのクラスに属するデータ数を可視化するときに使います.

以下の例では,手書き数字データセットのdigitsを用いてそれぞれの数字にいくつのデータがあるか調べています.

f:id:ket-30:20170521172122p:plain

今回の例ではほとんど同数なので問題ありませんが,
データ数に偏りがあった場合は重み付けなどしないといけませんからね.
この棒グラフの可視化も分析には重要なstepです.

サンプルコード

Usageでも紹介した通り,

  1. traceを作成する.
  2. layoutを定義する.
  3. iplot, もしくはplotで作図.
    の流れです.

3. Statistical and Seaborn-style Charts

3.1 Error Bars

Error Barsとは誤差付き折れ線グラフのことです.ここで紹介しているのは厳密に言うとBasic Continuous Error Barsですが,まあ気にしないでください.多分こっちの方が使います.

使いどころはベイズ線形回帰などがパッと思いつくところです.ベイズ線形回帰だと確率が見えないから微妙かな.まぁいつか何かで実践して載せるつもりです.いつかね!

以下の図とサンプルコードは本家のDocumentをoffline ver.に書き換えただけのほぼ同じものです.

f:id:ket-30:20170522235157p:plain

サンプルコード

  1. traceを作成する.
  2. layoutを定義する.
  3. iplot, もしくはplotで作図.
    の流れです.

3.2 Box Plots (箱ひげ図)

有名な図ですけど,自分で作図したことはほとんどないです.

これも本家のDocumentをoffline に書き換えただけです.申し訳ない.

f:id:ket-30:20170523000145p:plain

サンプルコード

  1. traceを作成する.
  2. layoutを定義する.
  3. iplot, もしくはplotで作図.
    の流れです.

3.3 Histograms

きました,定番のヒストグラムです.めっちゃ使います.

棒グラフと似てますけど違いますからねー.

まずは正規分布から適当にデータをサンプリングして最もシンプルなヒストグラムを生成してみましょう.

こんな感じになります.

f:id:ket-30:20170523002714p:plain

若干きもいヒストグラムになっちゃいました.

サンプルコード

import plotly.offline as offline
import plotly.graph_objs as go
offline.init_notebook_mode()

import numpy as np

x = np.random.randn(500)
data = [go.Histogram(x=x)]

offline.iplot(data, filename='basic histogram')

しかし,分析しているときは何かしらのデータを比較していることも多いですよね.

1つのデータごとに1つずつヒストグラムを作るのはダサいです.

ってことで多分こういうグラフの方が一般的に使うのかなと思います.

f:id:ket-30:20170523004712p:plain

若干ずらして表示してくれるので見やすいですね.

サンプルコード

3.4 2d Histgrams

2つのヒストグラムを使って作図します.次の図は2次元正規分布を無理やり離散に書き換えたものと考えるとわかりやすいかもしれません.

f:id:ket-30:20170527174157p:plain

カウントした総数で正規化すればこのままの状態で確率分布になります.

こういうグラフ見てると周辺化したくなってきますよね.

サンプルコード

3.5 2d Density Plots

2D Histogramと似てますが,こちらは連続データを扱うときに使います.

irisデータに含まれるsetosaのsepal lengthとsepal widthを使って分布を確認してみましょう.

f:id:ket-30:20170521202709p:plain

ちょっとデータが少ないですね.しかもこの多次元データは2次元正規分布に従うというより,正の相関を持ったデータっぽいですね.

多次元正規分布の作図として適してないかもしれませんが,まぁこういうことも作図して初めてわかるときもあるよってことで許してください.

サンプルコード

4 Scientific Charts

4.1 Heatmaps

お次はヒートマップです.3種類の変数の関係性を見たいときに使います. Qiitaのこちらの記事がseabornに含まれているわかりやすいデータを用いているので,こちらと同様にグラフを作ってみましょう.

f:id:ket-30:20170527163923p:plain

色の濃さは乗客数を表しているので,乗客数と年月の相関を確認することができます.パッと見ただけで,12月は帰省などで多いのかな?や,1955年付近には何があったのだろう?と分析の目処を立てることができますね.

また模様の違いがはっきり出ている方が,その変数は特徴量として大きな情報を持っていると判断することもできます.つまり特徴量選択の際にも使うことができます.

サンプルコード

4.2 Dendrograms

主に階層的クラスタリングで使うDendrograms,いわゆる系統樹の紹介です.階層的クラスタリングってなんやねんって方は,こちらを参考にして見てください.

ヒートマップと組み合わせて用いることが多いのですが,そこに関してはseabornの方が簡単にできるような気がしてます.とりあえずここで紹介するのは基本的なDendrogramsってことで許してください.

一応階層クラスタリングを簡単に説明すると,それぞれのデータごとに"キョリ"を計算し,近いものから同じグループとして結合していく手法です."キョリ"の計算方法は色々あるので,それは別の記事で書こうかと.できたらリンク貼りますねー.

結果的にはこんな図ができます.

f:id:ket-30:20170527173110p:plain

サンプルコード

系統樹figure_factoryを用いて作図するのですが,Layoutを扱う際に,若干の注意が必要です.コメントで書いておきましたので,そちらを参考にして見てください.

実践編

暇できたら書く

まとめ

少なくともmatplotlibよりは覚えやすいし,デザイン的にかっこいいグラフが作れると僕は思っています.

またplotlyの真髄は3Dの作図なので,3Dグラフのまとめもすぐに書きますね.

以上です.

崩壊型ギブスサンプリングを用いたトピックモデルの実装

こんにちは.

GWですね.最高の実装日和です.

トピックモデル (機械学習プロフェッショナルシリーズ)

トピックモデル (機械学習プロフェッショナルシリーズ)

今回は上の本を読み,p77の実験を追試したので簡単にまとめます.

学習の際には持橋先生のスライドがとても参考になりました.

余談ですがこの本は数式展開がとても丁寧です.ディリクレ分布に慣れてしまえば数式レベル自体も高くなく,とても読みやすいです. 疑似コードも書いてあり実装する際にとても役立ちます.

この記事では理論に関して詳しい説明しませんので,気になっている方がいたら是非購入することをお勧めします.

さてダイレクトマーケティングが済んだので本題に行きましょう.

トピックモデル (Topic Model)

まずは,簡単にトピックモデルの解説をしていきます.

トピックモデルとは文書データの解析手法として提案されたもので,大量の文書データから注目されているトピックを抽出することができます.

またトピックモデルを用いると,文書を構成する単語にトピックを付与することができます.文書そのものではなく単語にトピックを付与することで,ひとつの文書に複数のトピックを与えることが可能になり,より厳密な分析をすることができます.

f:id:ket-30:20170505003147p:plain

上の図はそれぞれの単語にトピックを付与するイメージ図です.
単語の色は,その単語がどのトピックに属しているのかを表しています.図の中では緑,青,赤の三色しかトピックの定義づけをしていませんが,黄色やオレンジも何らかのトピックに属していると考えてください.

一例として1つ目の文書は,最近起こった事件に基づいて僕が適当に作った文書です.
この文書をトピックモデルによって分析した場合,サッカーに関係する単語が多くスポーツのトピックを持っている反面,爆発という単語から事件性を孕んだ文書であるということがわかります.

崩壊型ギブスサンプリング

崩壊型ギブスサンプリングは,トピックモデルを学習させる手法の1つです. 今回の実装もこの手法を使っています.

先に文字の定義をしておきましょう.


D : 文書数

 K : トピック数

N_d : 文書dに含まれる単語数 (文書長)

 V : 全文書で現れる単語の種類数 (語彙数)

 \boldsymbol{W} : 文書集合

 \boldsymbol{w_d} : 文書d

 w_{dn} : 文書dn番目の単語

 N_k : 文書集合全体でトピックkが割り振られた単語数

 N_{dk} : 文書dでトピックkが割り振られた単語数

 N_{kv} : 文書集合全体で語彙vにトピックkが割り振られた単語数

 \theta_{dk} : 文書dでトピックkが割り当てられる確率

 \phi_{kv} : トピックkのとき,語彙vが生成される確率

 z_{dn} : 文書dn番目の単語に付与されたトピック

 \boldsymbol{Z} : トピック集合


この手法の勘所はパラメータの積分消去です. トピック分布集合 \boldsymbol{\Theta}と単語分布集合\boldsymbol{\Phi}を次のように周辺化することができます.


\displaystyle
\iint p( \boldsymbol{W},  \boldsymbol{Z} , \boldsymbol{\Theta}, \boldsymbol{\Phi} \,|\, \alpha, \beta) \,d\boldsymbol{\Theta} \, d\boldsymbol{\Phi}
= p( \boldsymbol{W},  \boldsymbol{Z} \,|\, \alpha, \beta)
\tag{1} \label{1}

ここで \alpha, \betaは事前分布のパラメータを表しています.一様ディリクレ分布を仮定しているので,ベクトルではなくスカラーです.

このようにパラメータを積分消去することで推定するパラメータの数を減らし,より効率的な推定が可能になります.

またギブスサンプリングをするには,サンプリング式が必要です.

文書dn番目の単語がトピック kに分類される確率は,そのトピックを除いたトピック集合\boldsymbol{Z}_{ \backslash dn }  と文書集合 \boldsymbol{W}が与えられたときの条件付き確率


p(  z_{dn} = k \,| \, \boldsymbol{W},  \boldsymbol{Z}_{\backslash dn}, \alpha, \beta) \\\
\propto p(  z_{dn} = k \,| \, \boldsymbol{Z}_{\backslash dn}, \alpha)p(w_{dn} \,| \, \boldsymbol{W}_{\backslash dn}, z_{dn} = k , \boldsymbol{Z}_{\backslash dn}, \beta)

で与えられます.

それぞれの項は式 \eqref{1}の右辺から,ディリクレ分布を用いて計算できます.結果的にサンプリング式は次のように求められます.


\displaystyle
p(  z_{dn} = k \,| \, \boldsymbol{W},  \boldsymbol{Z}_{\backslash dn}, \alpha, \beta) 
\propto (N_{dk \backslash dn} + \alpha)\frac{ N_{kw_{dn} \backslash dn} + \beta }{ N_{k \backslash dn} + \beta V }

ハイパーパラメータ \alpha, \beta不動点反復法で推定することができ,更新式は以下のようになります.ちなみに\Psi(\cdot)はディガンマ関数です.


\displaystyle
\alpha^{ \rm{new} } = \alpha \frac{ \sum_{d=1}^{D} \sum_{k=1}^{K} \Psi(N_{dk} + \alpha) - DK \Psi(\alpha) }{ K\sum_{d=1}^{D} \Psi(N_d + \alpha K) - DK \Psi(\alpha K)} \\\
\\\
\displaystyle
\beta^{ \rm{new} } = \beta \frac{ \sum_{k=1}^{K} \sum_{v=1}^{V} \Psi(N_{kv} + \beta) - KV \Psi(\beta) }{ V\sum_{k=1}^{K} \Psi(N_k + \beta V) - KV \Psi(\beta V)}

これらの式を用いて,

  1. 単語ごとにサンプリング確率を計算し,トピックを付与.
  2. 全ての単語にトピックが振られたら,ハイパーパラメータの更新.

の手順を収束するまで繰り返すことでトピックを自動で抽出していきます.

実装&実験

実験は青本のp77に書いてある方法とほぼ同じ条件で行いました. 言語はpythonです.

  1. 日本語wikipediaから10万文書抽出する.
  2. その中から頻出単語5000語彙を抽出し語彙集合とする.
  3. ランダムに1万文書を選択し,語彙集合に基づいたBOWを作成する.
  4. トピックモデルを用いて,トピックを抽出する.

手順はざっとこんな感じです. 青本ではトピック数を50にして実験した結果が載っていますが,手元の実験ではトピック数を20にして実験しました.

実行時間は一単語ずつ見ているせいか,100epochで12時間程度かかりました.なかなか時間かかってます.多分何も工夫せず実装しているせいもあると思うので,何かしらアドバイスがある方はぜひコメントください.よろしくお願いします.

以下実装したコードの一部です.
全コードはgitに上げているのでそちらを参照願います.

github.com

class TopicModel():

    def __init__(self, BOWs, K=20, V=5000, max_words=2000, ratio=0.9 ,alpha=1.0, beta=1.0):
        self.BOWs = BOWs
        border = int(ratio * self.BOWs.shape[0])

        self.train_BOWs, self.test_BOWs = np.vsplit(self.BOWs, [border])

        self.V = V
        self.K = K

        self.alpha = alpha
        self.beta  = beta

        self.D = self.train_BOWs.shape[0] 
        self.test_D = self.test_BOWs.shape[0]

        self.N_dk = np.zeros([self.D, self.K]) 
        self.N_kv = np.zeros([self.K, self.V]) 
        self.N_k  = np.zeros([self.K, 1]) 

        self.z_dn = np.zeros([self.D, max_words]) - 1 

    def fit(self, epoch=100):

        self.pplx_ls = np.zeros([epoch])

        for e in range(epoch):
            print("Epoch: {}".format(e+1))

            for d, BOW in enumerate(self.train_BOWs):
                sys.stdout.write("\r%d / %d" % (d+1, self.train_BOWs.shape[0]))
                sys.stdout.flush()

                for n, v in enumerate(BOW):
                    if v < 0: break

                    current_topic = int(self.z_dn[d, n])

                    # reset information of d-th BOW
                    if current_topic >= 0:
                        self.N_dk[d, current_topic] -= 1
                        self.N_kv[current_topic, v] -= 1
                        self.N_k[current_topic] -= 1

                    # sampling
                    p_z_dn = self._calc_probability(d, v)
                    new_topic = self._sampling_topic(p_z_dn)
                    self.z_dn[d, n] = new_topic

                    # update counting
                    self.N_dk[d, new_topic] += 1
                    self.N_kv[new_topic, v] += 1
                    self.N_k[new_topic] += 1


            # update α
            numerator = np.sum(digamma(self.N_dk+self.alpha))\
                      - self.D*self.K*digamma(self.alpha)
            denominator = self.K*(np.sum(digamma(np.count_nonzero(self.train_BOWs+1,axis=1)+self.alpha*self.K))\
                        - self.D*digamma(self.alpha*self.K))
            self.alpha *= numerator / denominator

            # update β
            numerator = np.sum(digamma(self.N_kv+self.beta)) - self.K*self.V*digamma(self.beta)
            denominator = self.V*(np.sum(digamma(self.N_k+self.beta*self.V)) - self.K*digamma(self.beta*self.V))
            self.beta *= numerator / denominator

実験結果

最終的に抽出されたトピック例をpandasを使って眺めてみましょう.そのトピックに割り当てられた数が多い順に単語が並んでいます.

f:id:ket-30:20170506143822p:plain

目がチカチカする・・・.

1個ずつ確認してみましょう.

まずはトピック2.

'こと', '数', '的', '関数', 'よう', '値', '定義', 
'集合', '空間', '単位', '計算', 'とき', '場合', 
'もの', '上', 'これ', '元', '次', '方程式', '点'

完全に数学ですね.うまいこと数学というトピックが抽出できています.

もうひとつ,トピック9.

'年', '戦', 'こと', '選手', '大会', '競技', 'チーム', 
'開催', '位', 'リーグ', '試合',  '日本', '人', '優勝', 
'者', 'ため', '野球', 'レース', 'オリンピック', 'プロ'

これは野球トピックが抽出できていますね.いい感じです.

トピックモデルを用いることで意味付けが容易なトピックを抽出できていることが確認できました.

考察

最終的に得られた事前分布のパラメータに注目してみます.

alpha :0.051928737413754714
beta :0.09577638564520818

\alpha\betaも0に近く,かなり小さな値になっていますね.

このパラメータに基づくディリクレ分布の振る舞いを可視化してみましょう. 可視化のコードはこれらの記事をパクりました.
多項分布とディリクレ分布のまとめと可視化 - ★データ解析備忘録★
Visualizing Dirichlet Distributions with Matplotlib
ありがとうございました.

まずはパラメータが(0.1, 0.1, 0.1)の3次元のディリクレ分布をプロットしてみましょう.

f:id:ket-30:20170505150931p:plain

んー,一様ですか?よくわからないんで,パラメータを(0.99, 0.99, 0.99)にしてもう一度プロットしてみます.

f:id:ket-30:20170505151253p:plain

赤の方が値が大きく,青に近づくほど値は小さいです.
ということで,極端に偏った分布になっていたわけですね.

ここから生成されるベクトルのほとんどが, (1, 0, 0) (0, 1, 0) (0, 0, 1)ということになります.

つまり,One-hotに近いベクトルが出現しやすい事前分布になっているわけですね.

事前分布は,"あっちのトピックとこっちのトピックがあり得そう"などという曖昧なものになりにくいことがわかりました.

おまけ

学習中はトピック1とトピック2の上位10単語をプリントさせました.
以下が学習の遷移です.

Epoch1

最初はほぼランダムなので一貫性はありません.

Epoch: 1
7343 / 7343
parameters
alpha :0.8956689597592044
beta :0.7423208433263517
---------------------
    topic1
こと    3802
年     3739
ため    2947
もの    2382
の     1848
場合    1502
数     1359
現在    1185
部     1009
万     1006
---------------------
    topic2
よう  4262
こと  2417
ため  2329
場合  2169
的   2078
部分  1382
これ  1270
もの  1251
関数  1208
漫画  1181
*********************

Epoch50

トピック2に数学系の単語が集まってきています. トピック1はこの段階ではよくわかりませんね・・・.

Epoch: 50
7343 / 7343
parameters
alpha :0.05620034268917989
beta :0.09405920652272919
---------------------
    topic1
こと    4213
もの    3076
よう    2191
ため    1969
日本    1850
的     1527
場合    1449
種     1225
の     1169
これ    1103
---------------------
    topic2
こと  6929
的   3302
数   3086
よう  2869
関数  2530
値   2311
定義  2197
集合  1975
空間  1954
単位  1905
*********************

Epoch150

トピック1にもなんとなく一貫性が出てきてます.
でもトピック1に名前つけるのムズイですねw
ダーウィンが来たとかですか?

Epoch: 150
7343 / 7343
parameters
alpha :0.051243695865978385
beta :0.09653944478736103
---------------------
    topic1
こと    2684
もの    1280
ため    1235
よう    1184
的     1153
年     1036
大陸    1030
動物     999
地球     958
種      942
---------------------
    topic2
こと  6211
数   2977
的   2833
関数  2493
よう  2456
値   2170
定義  2116
集合  1993
空間  1842
単位  1804
*********************

あとがき

本当はトピックモデルをユニグラムモデルから説明する記事を書こうとしてたんですが,途中で挫折してしまいました.

とにかく時間がかかりすぎる・・・.

一応下書きは保存してあるので,機会があれば書き切りたいです.多分書かないですがw

以上です.