読者です 読者をやめる 読者になる 読者になる

MATHGRAM

主に数学とプログラミング、時々趣味について。

崩壊型ギブスサンプリングを用いたトピックモデルの実装

こんにちは.

GWですね.最高の実装日和です.

トピックモデル (機械学習プロフェッショナルシリーズ)

トピックモデル (機械学習プロフェッショナルシリーズ)

今回は上の本を読み,p77の実験を追試したので簡単にまとめます.

学習の際には持橋先生のスライドがとても参考になりました.

余談ですがこの本は数式展開がとても丁寧です.ディリクレ分布に慣れてしまえば数式レベル自体も高くなく,とても読みやすいです. 疑似コードも書いてあり実装する際にとても役立ちます.

この記事では理論に関して詳しい説明しませんので,気になっている方がいたら是非購入することをお勧めします.

さてダイレクトマーケティングが済んだので本題に行きましょう.

トピックモデル (Topic Model)

まずは,簡単にトピックモデルの解説をしていきます.

トピックモデルとは文書データの解析手法として提案されたもので,大量の文書データから注目されているトピックを抽出することができます.

またトピックモデルを用いると,文書を構成する単語にトピックを付与することができます.文書そのものではなく単語にトピックを付与することで,ひとつの文書に複数のトピックを与えることが可能になり,より厳密な分析をすることができます.

f:id:ket-30:20170505003147p:plain

上の図はそれぞれの単語にトピックを付与するイメージ図です.
単語の色は,その単語がどのトピックに属しているのかを表しています.図の中では緑,青,赤の三色しかトピックの定義づけをしていませんが,黄色やオレンジも何らかのトピックに属していると考えてください.

一例として1つ目の文書は,最近起こった事件に基づいて僕が適当に作った文書です.
この文書をトピックモデルによって分析した場合,サッカーに関係する単語が多くスポーツのトピックを持っている反面,爆発という単語から事件性を孕んだ文書であるということがわかります.

崩壊型ギブスサンプリング

崩壊型ギブスサンプリングは,トピックモデルを学習させる手法の1つです. 今回の実装もこの手法を使っています.

先に文字の定義をしておきましょう.


D : 文書数

 K : トピック数

N_d : 文書dに含まれる単語数 (文書長)

 V : 全文書で現れる単語の種類数 (語彙数)

 \boldsymbol{W} : 文書集合

 \boldsymbol{w_d} : 文書d

 w_{dn} : 文書dn番目の単語

 N_k : 文書集合全体でトピックkが割り振られた単語数

 N_{dk} : 文書dでトピックkが割り振られた単語数

 N_{kv} : 文書集合全体で語彙vにトピックkが割り振られた単語数

 \theta_{dk} : 文書dでトピックkが割り当てられる確率

 \phi_{kv} : トピックkのとき,語彙vが生成される確率

 z_{dn} : 文書dn番目の単語に付与されたトピック

 \boldsymbol{Z} : トピック集合


この手法の勘所はパラメータの積分消去です. トピック分布集合 \boldsymbol{\Theta}と単語分布集合\boldsymbol{\Phi}を次のように周辺化することができます.


\displaystyle
\iint p( \boldsymbol{W},  \boldsymbol{Z} , \boldsymbol{\Theta}, \boldsymbol{\Phi} \,|\, \alpha, \beta) \,d\boldsymbol{\Theta} \, d\boldsymbol{\Phi}
= p( \boldsymbol{W},  \boldsymbol{Z} \,|\, \alpha, \beta)
\tag{1} \label{1}

ここで \alpha, \betaは事前分布のパラメータを表しています.一様ディリクレ分布を仮定しているので,ベクトルではなくスカラーです.

このようにパラメータを積分消去することで推定するパラメータの数を減らし,より効率的な推定が可能になります.

またギブスサンプリングをするには,サンプリング式が必要です.

文書dn番目の単語がトピック kに分類される確率は,そのトピックを除いたトピック集合\boldsymbol{Z}_{ \backslash dn }  と文書集合 \boldsymbol{W}が与えられたときの条件付き確率


p(  z_{dn} = k \,| \, \boldsymbol{W},  \boldsymbol{Z}_{\backslash dn}, \alpha, \beta) \\\
\propto p(  z_{dn} = k \,| \, \boldsymbol{Z}_{\backslash dn}, \alpha)p(w_{dn} \,| \, \boldsymbol{W}_{\backslash dn}, z_{dn} = k , \boldsymbol{Z}_{\backslash dn}, \beta)

で与えられます.

それぞれの項は式 \eqref{1}の右辺から,ディリクレ分布を用いて計算できます.結果的にサンプリング式は次のように求められます.


\displaystyle
p(  z_{dn} = k \,| \, \boldsymbol{W},  \boldsymbol{Z}_{\backslash dn}, \alpha, \beta) 
\propto (N_{dk \backslash dn} + \alpha)\frac{ N_{kw_{dn} \backslash dn} + \beta }{ N_{k \backslash dn} + \beta V }

ハイパーパラメータ \alpha, \beta不動点反復法で推定することができ,更新式は以下のようになります.


\displaystyle
\alpha^{ \rm{new} } = \alpha \frac{ \sum_{d=1}^{D} \sum_{k=1}^{K} \Psi(N_{dk} + \alpha) - DK \Psi(\alpha) }{ K\sum_{d=1}^{D} \Psi(N_d + \alpha K) - DK \Psi(\alpha K)} \\\
\\\
\displaystyle
\beta^{ \rm{new} } = \beta \frac{ \sum_{k=1}^{K} \sum_{v=1}^{V} \Psi(N_{kv} + \beta) - KV \Psi(\beta) }{ V\sum_{k=1}^{K} \Psi(N_k + \beta V) - KV \Psi(\beta V)}

これらの式を用いて, 1. 単語ごとにサンプリング確率を計算し,トピックを付与.
1. 全ての単語にトピックが振られたら,ハイパーパラメータの更新.

の手順を収束するまで繰り返すことて. トピックを自動で抽出していきます.

実装&実験

実験は青本のp77に書いてある方法とほぼ同じ条件で行いました. 言語はpythonです.

  1. 日本語wikipediaから10万文書抽出する.
  2. その中から頻出単語5000語彙を抽出し語彙集合とする.
  3. ランダムに1万文書を選択し,語彙集合に基づいたBOWを作成する.
  4. トピックモデルを用いて,トピックを抽出する.

手順はざっとこんな感じです. 青本ではトピック数を50にして実験した結果が載っていますが,手元の実験ではトピック数を20にして実験しました.

実行時間は一単語ずつ見ているせいか,100epochで12時間程度かかりました.なかなか時間かかってます.

以下実装したコードの一部です.
全コードはgitに上げているのでそちらを参照願います.

github.com

class TopicModel():

    def __init__(self, BOWs, K=20, V=5000, max_words=2000, ratio=0.9 ,alpha=1.0, beta=1.0):
        self.BOWs = BOWs
        border = int(ratio * self.BOWs.shape[0])

        self.train_BOWs, self.test_BOWs = np.vsplit(self.BOWs, [border])

        self.V = V
        self.K = K

        self.alpha = alpha
        self.beta  = beta

        self.D = self.train_BOWs.shape[0] 
        self.test_D = self.test_BOWs.shape[0]

        self.N_dk = np.zeros([self.D, self.K]) 
        self.N_kv = np.zeros([self.K, self.V]) 
        self.N_k  = np.zeros([self.K, 1]) 

        self.z_dn = np.zeros([self.D, max_words]) - 1 

    def fit(self, epoch=100):

        self.pplx_ls = np.zeros([epoch])

        for e in range(epoch):
            print("Epoch: {}".format(e+1))

            for d, BOW in enumerate(self.train_BOWs):
                sys.stdout.write("\r%d / %d" % (d+1, self.train_BOWs.shape[0]))
                sys.stdout.flush()

                for n, v in enumerate(BOW):
                    if v < 0: break

                    current_topic = int(self.z_dn[d, n])

                    # reset information of d-th BOW
                    if current_topic >= 0:
                        self.N_dk[d, current_topic] -= 1
                        self.N_kv[current_topic, v] -= 1
                        self.N_k[current_topic] -= 1

                    # sampling
                    p_z_dn = self._calc_probability(d, v)
                    new_topic = self._sampling_topic(p_z_dn)
                    self.z_dn[d, n] = new_topic

                    # update counting
                    self.N_dk[d, new_topic] += 1
                    self.N_kv[new_topic, v] += 1
                    self.N_k[new_topic] += 1


            # update α
            numerator = np.sum(digamma(self.N_dk+self.alpha))\
                      - self.D*self.K*digamma(self.alpha)
            denominator = self.K*(np.sum(digamma(np.count_nonzero(self.train_BOWs+1,axis=1)+self.alpha*self.K))\
                        - self.D*digamma(self.alpha*self.K))
            self.alpha *= numerator / denominator

            # update β
            numerator = np.sum(digamma(self.N_kv+self.beta)) - self.K*self.V*digamma(self.beta)
            denominator = self.V*(np.sum(digamma(self.N_k+self.beta*self.V)) - self.K*digamma(self.beta*self.V))
            self.beta *= numerator / denominator

実験結果

最終的に抽出されたトピック例をpandasを使って眺めてみましょう.そのトピックに割り当てられた数が多い順に単語が並んでいます.

f:id:ket-30:20170506143822p:plain

目がチカチカする・・・.

1個ずつ確認してみましょう.

まずはトピック2.

'こと', '数', '的', '関数', 'よう', '値', '定義', 
'集合', '空間', '単位', '計算', 'とき', '場合', 
'もの', '上', 'これ', '元', '次', '方程式', '点'

完全に数学ですね.うまいこと数学というトピックが抽出できています.

もうひとつ,トピック9.

'年', '戦', 'こと', '選手', '大会', '競技', 'チーム', 
'開催', '位', 'リーグ', '試合',  '日本', '人', '優勝', 
'者', 'ため', '野球', 'レース', 'オリンピック', 'プロ'

これは野球トピックが抽出できていますね.いい感じです.

トピックモデルを用いることで意味付けが容易なトピックを抽出できていることが確認できました.

考察

最終的に得られた事前分布のパラメータに注目してみます.

alpha :0.051928737413754714
beta :0.09577638564520818

\alpha\betaも0に近く,かなり小さな値になっていますね.

このパラメータに基づくディリクレ分布の振る舞いを可視化してみましょう. 可視化のコードはこれらの記事をパクりました.
多項分布とディリクレ分布のまとめと可視化 - ★データ解析備忘録★
Visualizing Dirichlet Distributions with Matplotlib
ありがとうございました.

まずはパラメータが(0.1, 0.1, 0.1)の3次元のディリクレ分布をプロットしてみましょう.

f:id:ket-30:20170505150931p:plain

んー,一様ですか?よくわからないんで,パラメータを(0.99, 0.99, 0.99)にしてもう一度プロットしてみます.

f:id:ket-30:20170505151253p:plain

赤の方が値が大きく,青に近づくほど値は小さいです.
ということで,極端に偏った分布になっていたわけですね.

ここから生成されるベクトルのほとんどが, (1, 0, 0) (0, 1, 0) (0, 0, 1)ということになります.

つまり,One-hotに近いベクトルが出現しやすい事前分布になっているわけですね.

事前分布は,"あっちのトピックとこっちのトピックがあり得そう"などという曖昧なものになりにくいことがわかりました.

おまけ

学習中はトピック1とトピック2の上位10単語をプリントさせました.
以下が学習の遷移です.

Epoch1

最初はほぼランダムなので一貫性はありません.

Epoch: 1
7343 / 7343
parameters
alpha :0.8956689597592044
beta :0.7423208433263517
---------------------
    topic1
こと    3802
年     3739
ため    2947
もの    2382
の     1848
場合    1502
数     1359
現在    1185
部     1009
万     1006
---------------------
    topic2
よう  4262
こと  2417
ため  2329
場合  2169
的   2078
部分  1382
これ  1270
もの  1251
関数  1208
漫画  1181
*********************

Epoch50

トピック2に数学系の単語が集まってきています. トピック1はこの段階ではよくわかりませんね・・・.

Epoch: 50
7343 / 7343
parameters
alpha :0.05620034268917989
beta :0.09405920652272919
---------------------
    topic1
こと    4213
もの    3076
よう    2191
ため    1969
日本    1850
的     1527
場合    1449
種     1225
の     1169
これ    1103
---------------------
    topic2
こと  6929
的   3302
数   3086
よう  2869
関数  2530
値   2311
定義  2197
集合  1975
空間  1954
単位  1905
*********************

Epoch150

トピック1にもなんとなく一貫性が出てきてます.
でもトピック1に名前つけるのムズイですねw
ダーウィンが来たとかですか?

Epoch: 150
7343 / 7343
parameters
alpha :0.051243695865978385
beta :0.09653944478736103
---------------------
    topic1
こと    2684
もの    1280
ため    1235
よう    1184
的     1153
年     1036
大陸    1030
動物     999
地球     958
種      942
---------------------
    topic2
こと  6211
数   2977
的   2833
関数  2493
よう  2456
値   2170
定義  2116
集合  1993
空間  1842
単位  1804
*********************

あとがき

本当はトピックモデルをユニグラムモデルから説明する記事を書こうとしてたんですが,途中で挫折してしまいました.

とにかく時間がかかりすぎる・・・.

一応下書きは保存してあるので,機会があれば書き切りたいです.多分書かないですがw

以上です.