金融と工学のあいだ

興味関心に関するメモ(機械学習、検索エンジン、プログラミングなど)

Chainerによるword2vecのチュートリアル(1)

word2vecとは

導入

  • word2vecは単語の分散表現を生成する手法です。単語の意味が近いほど類似度が大きくなるように、各単語に実ベクトルを割り当てる手法です。
  • そもそも単語の意味とはなんでしょうか。
    • 人であれば動物と犬という単語が似ているというのはなんとなく分かります。
    • しかし、word2vecは何の情報を元に、動物と犬は似ているとか、食べ物と犬は似ていないといった意味の類似度を学習すれば良いのでしょうか。

基本的なアイデア

  • word2vecは単語の意味の類似度を単純な情報から学習します。
  • それは文章における単語の並びです。ある単語の意味は、その単語の周囲の単語で決まるというアイデアです。
  • 学習対象の単語をTarget Word、その周囲の単語をContext Wordと呼びます。ウィンドウサイズcに応じてContex Wordの数は変わります。
  • 例として、The cute cat jumps over the lazy dog. という文で説明を行います。
    • 以下の図は全てTarget Wordをcatとした場合のものです。
    • ウィンドウサイズcに応じて、catを学習する際に使用するContext Wordが変わることがわかると思います。
  • f:id:kumechann:20170628152453p:plain

主なアルゴリズム

  • word2vecと呼ばれる手法は実はSkip-gramとCBoWという2つの手法の総称です。
  • 以下で図を使って説明しますが、その時の記号の意味をここで説明します。
    • N:ボキャブラリ数。
    • D:分散表現のベクトルのサイズ。
    • {v_t} :Target Word。サイズは[N,1]。
    • {v_{t+c}}:Context Word。サイズは[N,1]。
    • {L_H}:隠れ層。サイズは[D,1]。
    • {L_O}:出力層。サイズは[N,1]。
    • {W_H}:隠れ層重み行列。サイズは[N,D]。
    • {W_O}:出力層重み行列。サイズは[D,N]。

Skip-gram

  • Target Word {v_t}が与えられた時、Context Word {v_{t+c}}が出現することを予測するように学習する
  • この時、隠れ層重み行列W_Hの各行がそれぞれの単語の分散表現になる。

  • f:id:kumechann:20170628153848p:plain

Continuous Bag of Words (CBoW)

  • Context Word {v_{t+c}}が与えられた時、Target Word {v_t}が出現することを予測するように学習する
  • この時、出力層重み行列{W_O}の各列がそれぞれの単語の分散表現になる。

  • f:id:kumechann:20170628153915p:plain

Skip-gramの詳細

  • チュートリアルでは、以下の観点からSkip-gramをメインで扱います。
    1. 学習アルゴリズムがCBoWに比べて理解しやすい
    2. 単語数が増えても精度が落ちにくく、スケールしやすい

具体例を使った説明

  • 上の例と同じように、ボキャブラリ数Nは10、分散表現のベクトルのサイズDは2とします。
  • 犬という単語をTarget Word、動物という単語をContext Wordとし学習する様子を説明します。
  • Context Wordは複数あるはずなので下記工程をContext Word分繰り返します。

    1. 犬という単語の局所表現は[0 0 1 0 0 0 0 0 0 0]であり、それをTarget Wordとして入力します。
    2. この時、隠れ層重み行列{W_H}の3行目が隠れ層{L_H}になります。
      • ちなみに、ここの値が学習後には犬という単語の分散表現になります。
    3. 出力層重み行列{W_O}と隠れ層{L_H}をかけた結果が出力層{L_O}となります。
    4. 出力層の各要素の値を制限するために、出力層{L_O}にsoftmax関数を適用し{W_O}を計算する
      • 最終的には出力層とContext Wordの誤差を出し、その誤差をネットワークに逆伝播することでパラメータの更新を行う必要があります。
      • しかし、出力層の各要素の値は範囲が制限されていないため-∞〜+∞までの値をとります。Context Wordの局所表現は、[1 0 0 0 0 0 0 0 0 0]のように各要素は0か1の値しか取りません。
      • 出力層の各要素の値を0〜1に制限するため、各要素の値を0〜1の範囲に制限する関数softmaxを適用します。
    5. {W_O}とanimalの局所表現[1 0 0 0 0 0 0 0 0 0]の誤差を計算し、その誤差をネットワークに逆伝播させてパラメータを更新する
  • f:id:kumechann:20170628155336p:plain

Chainerによる実装方法

実装方法

  • 基本的にchainerを使用する場合にはこのような形でimportをします。
    • functionsをF、linksをLのような形でimportすると使いやすいです。
  • 次にskip-gramのネットワーク構造の定義です。
    • コンストラク__init__に、ボキャブラリ数n_vocab、分散ベクトルのサイズn_units、損失関数loss_funcを渡すようになっています。
      • init_scope()内でParameterの初期化を行っています。
        • ここでParameterの初期化を行うことが推奨されています。
        • PrameterをLinkのattributeとして設定してくれるので、IDEでコードを追いやすくなるなどの効果があります。
        • 詳しくはここ Upgrade Guide — Chainer 6.3.0 documentation
      • ここで、self.embed内の重み行列Wが隠れ層重み行列{W_H}にあたります。
      • 注意してもらいたいのが、Skip-gramの場合、Target WordとContext Wordを1対1で対応するため、Context WordとTarget Wordを入れ替えても問題がなく、Context WordとTarget Wordを入れ替えて学習させています。(CBoWモデルとコードの整合性が取りやすいからです。)
    • 関数呼び出し__call__は、Target WordのIDx、Context WordのIDcontextを受取り損失関数loss_funcで誤差を返します。
      • e = self.embed(context)でcontextに対応する分散表現を取得しています。
      • batch_size分のTarget WordのIDxをContext Wordの数だけbroad castします。
      • x[batch_size * n_context,], e[batch_size * n_context, n_units]になります。
  • 損失関数の定義です。実質的には、skip-gramのネットワーク構造の定義をしています。
    • xに対して重み行列による線形写像self.out(x) (self.out:= L.Linear(n_in, n_out, initialW=0)(x))を計算した後、F.softmax_cross_entropyを計算します。
    • ここで、線形写像self.out(x)は出力層重み行列{W_O}F.softmax_cross_entropyはsoftmax関数と損失計算部分に該当します。
  • Iteratorの定義
    • コンストラク__init__に、単語idのリストによる文書データセットdataset、ウインドウサイズwindow、ミニバッチサイズbatch_sizeを渡すようになっています。
      • この中で、文書中の単語の位置をシャッフルしたリストself.orderを作成しています。それは学習する時に、文書の最初から最後まで順番に学習するのではなく、文書からランダムに単語を選択し学習するようにするためです。ウィンドウサイズ分だけ最初と最後を切り取った単語の位置がシャッフルされて入っています。
      • 例:文書データセットdataset中の単語数が100個、ウインドウサイズwindowが5だった場合、self.orderは5から94までの数字がシャッフルされたnumpy.arrayになる。
    • イテレータの定義__next__は、コンストラクタのパラメータに従ってミニバッチサイズ個のTarget Word centerとContext Word contextを返します。
      • position = self.order[i:i_end]で、単語の位置をシャッフルしたリストself.orderからbatch_size分のTarget Wordのインデックスpositionを生成します。(positionは後でself.dataset.takeによってTarget Word centerに変換されます。)
      • offset = np.concatenate([np.arange(-w, 0), np.arange(1, w + 1)])で、ウインドウを表現するオフセットoffsetを作成しています。
      • pos = position[:, None] + offset[None, :]によって、それぞれのTarget Wordに対するContext Wordのインデックスposを生成します。(posは後でself.dataset.takeによってContext Word contextに変換されます。)
  • main関数
    • データ取得
      • trainvalにはそれぞれtrainingデータ、validationデータが入っています。単語のidの列で文書を表現しています。

          >>> train
          array([ 0,  1,  2, ..., 39, 26, 24], dtype=int32)
          >>> val
          array([2211,  396, 1129, ...,  108,   27,   24], dtype=int32)
        
      • trainingデータtrainに含まれる最大のid+1の値がボキャブラリ数n_vocabになります。

    • 損失関数作成
    • モデル作成
    • Optimizer作成、Trainer作成、実行を下記コードで行っています。
      • 注意してほしいのは、CPUでもGPUでも動くように、GPUの場合にだけGPU上のメモリにデータを転送する関数convertが引数converterとして渡されています。
      • また、trainer.extendによってログ出力などの機能拡張が行われています。

実行方法

$ pwd
/root2chainer/chainer/examples/word2vec
$ $ python train_word2vec.py --test  # test modeで実行。全データで学習したいときは--testを消去
GPU: -1
# unit: 100
Window: 5
Minibatch-size: 1000
# epoch: 20
Training model: skipgram
Output type: hsm

n_vocab: 10000
data length: 100
epoch       main/loss   validation/main/loss
1           4233.75     2495.33               
2           1411.14     4990.66               
3           4233.11     1247.66               
4           2821.66     4990.65               
5           4231.94     1247.66               
6           5642.04     2495.3                
7           5640.82     4990.64               
8           5639.31     2495.28               
9           2817.89     4990.62               
10          1408.03     3742.94               
11          5633.11     1247.62               
12          4221.71     2495.21               
13          4219.3      4990.56               
14          4216.57     2495.16               
15          4213.52     2495.12               
16          5616.03     1247.55               
17          5611.34     3742.78               
18          2800.31     3742.74               
19          1397.79     2494.95               
20          2794.1      3742.66

リファレンス