The jonki

呼ばれて飛び出てじょじょじょじょーんき

2020年にやってよかった教材(機械学習関連の入門多め)

今年は仕事内容も変わって,いろいろなものを勉強した気がします.ということで買ってよかったもの,とは別に.やってよかった教材(書籍,オンライン教材,ブログ等)を紹介しようと思います.入門系多めです.

Andrew Ng先生

今年はAndrew Ng(アンドリュー・エング)先生の大ファンになりました.Twitter界隈でまずNg先生のCourseraやCS229をやれ,という話は時折上がってくるので見てみたらハマりました.私は2017年ぐらいからNLPを始めて,機械学習の知識は必要に応じて勉強していたので,体系だってあまり学んで来ませんでした(高村先生の機械学習入門ぐらい). そこでゼロから勉強し直そうと思い,Andrew Ng先生の教材をやってみました.

結論から言うと非常に良かったです.これから機械学習始めたいという方に特にオススメ,というかこれ以外から始めない方が良いかも,と強く言いたいぐらい良かったです.いきなりDNNに飛びついてしまう人も最近は多いのかなと思うのですが,ここでロジスティック回帰や評価の仕方などしっかり学んだほうがよいです.授業では各項目に対して深入りするわけではないですが,考え方のエッセンスのようなものをAndrew Ng先生がわかりやすく説明してくれます.これがこの授業の最大の魅力です. Courseraには確認テストがあって,選択・記入式のチェックテストとOctaveによる実装テストがあります.どちらもそんなに難しくないので,良い復習になると思います.え,Octave?と思うかもしれませんが,OctavePythonMatlabの文法に馴染みがあれば,ちょっとしたチュートリアルもあるので問題になりません.実装もロジックの部分だけ実装する形式なので,Octaveの言語に精通している必要はまったくありません. コースは11週と長く見えますが,まったくの初心者でなければ授業は1.5倍速ぐらいで,1ヶ月ぐらいで終わると思います.授業は英語ですが,リスニングに自信がなくても板書がしっかりしてるので十分ついていけると思います.日英の字幕もあります.非常に有名な授業なので,修了証の権利も購入してLinkedinに貼るつもりでやると,モチベーションも続くかもしれません.

www.coursera.org

この授業の後は,CS229というスタンフォードでの授業を見てみました.Youtubeで2018年度に行われたものが今年4月に無料で公開されています.Courseraの授業は完全に入門編でしたが,こちらは入門ではありますが,範囲をやや広げ(EMアルゴリズム強化学習等),数式的な踏み込みもかなりしています.こちらはシラバスを見てもらえればわかりますが,動画での授業だけでなく,多くのLecture NotesがPDFとしてまとまっています.動画中での導出は時間上カットしてたりするのですが,ちゃんとその導出などもLecture Notesに書いています.シラバス上はリンク切れの資料も多いのですが,世界中で人気の授業なのでGithubなどで探せばいくらでも関連資料は出てきます.またこの授業のテストもあって,Problem Setsというのがあります.こちらもネットを探せば,元の課題が転がっているので挑戦してみても良いかもしれません.が,結構重たかったので,私は動画→Lecture Note読む,で先に進んじゃいました. www.youtube.com

また本授業を担当したTAの方が俯瞰図を作っているので,どのあたり勉強できそうか眺めても面白いです. github.com

機械学習の入門が終わったら,DNNやNLPなどを勉強しても面白いかもしれません.同じStanfordでの授業はいっぱい転がっているので,自分にあったものをやると楽しそうです.

Stanford CS230: Deep Learning | Autumn 2018 | Lecture 1 - Class Introduction and Logistics - YouTube

Stanford CS224N: NLP with Deep Learning | Winter 2019 | Lecture 1 – Introduction and Word Vectors - YouTube

NLPだとCMUのGraham Neubig先生も毎年授業を即公開しており,最新の論文のキャッチアップもできて良いですね.

Graham Neubig - YouTube

線形代数

色々論文などを読む上で線形代数の基礎力がまったく足りてないなと反省し,色々線形代数もあさりました.

こちらもCourseraであった線形代数の授業です.こちらも演習とセットになっていて,短いのでよいです.

www.coursera.org

またヨビノリシリーズも良かったです.書籍は動画と基本的に同じなので,動画を見た後に書籍を見ると良い復習になります.

www.youtube.com

予備校のノリで学ぶ線形代数

予備校のノリで学ぶ線形代数

音声信号処理

今年はNLPだけでなく,音声系の内容にも手を出しました.といっても大学以来,まったくやってこなかったので基礎の基礎から色々勉強しましたが,中でも良かったものを紹介します.

まず東北大の鏡先生のディジタル信号処理はかなり良かったです.会話形で進むタイプなので,疑問に思いやすい点などしっかりとやる夫が突っ込んでくれて面白かったです.これで完全に忘れていたフーリエ系の扱いを大学時代以上に思い出せました. www.ic.is.tohoku.ac.jp

あとは本もいくつか買いましたが,サクッと入門系だとこの辺がわかりやすかったです.やる夫シリーズはFFTの解説があえてないのですが,この本はFFTの解説(というか計算)がわかりやすかったです.

道具としてのフーリエ解析

道具としてのフーリエ解析

またどうしても信号処理を扱う上で,複素数は外せないところなので,これも読みました.こういうシリーズはイラストで安心感を装い,中身はあまり...という場合が多いのですが,これは良書だと思いました.ちなみに漫画部分も普通に面白かったです笑

マンガでわかる虚数・複素数

マンガでわかる虚数・複素数

  • 作者:相知 政司
  • 発売日: 2010/11/12
  • メディア: 単行本(ソフトカバー)

次にプログラミングでデジタル信号処理できないとまずいので,こちらも一通りやりました.実際に音声ファイルを読み込んで周波数解析をして可視化して,という一連の流れをサクッと学べます. aidiary.hatenablog.com

また,音声を扱う上で,人の聴覚や発声の仕組みも知りたくなりました.デジタル信号処理は音声とは限らないので,このあたりの情報はあまり載っていません.そこで良い書籍ないかなといろいろ探していましたが,この本は非常に読み物としても面白かったです.音声系の教科書,堅苦しい(失礼)のが多いんですが,これは読みやすい上にタメになる知識が多かったので,また読みたいところ.目次見て面白そうと思ったら買いでしょう.

ゼロからはじめる音響学 (KS理工学専門書)

ゼロからはじめる音響学 (KS理工学専門書)

プログラミング

これずっと前に購入して積んでいたので読んでみたのですが,アルゴリズムの勉強を初めてするときなどに読んでおきたかったです.大学時代にアルゴリズムの授業というとまったく楽しめなかったので,こういった実用的にアルゴリズムがどこで活躍しているか分かる本に出会いたかった.

またけんちょんさんの本も買っていて,冬休み時間が取れたのでやってみました.競プロをやる上で,アルゴリズムの勉強をする→AtCoderで実力を試す,など考えられますが,体系的にアルゴリズムを俯瞰して勉強でき,更に実装にも踏み込んだ入門書はこれまでなかったように感じるので新鮮でした.入門〜緑ぐらいの人向けのイメージです.アルゴリズム何も知らんという人は,アルゴリズム図鑑から入ったほうが良いかもしれないです.

アルゴリズム図鑑 絵で見てわかる26のアルゴリズム

アルゴリズム図鑑 絵で見てわかる26のアルゴリズム

Youtuber

最後にYoutubeで見つけた良いチャンネルです.CS229しかりでYoutubeに良い教材が大量にあふれる良い時代になりました.論文の解説などもあったりするので,本当にYoutubeよく見るようになりました. ということで登録しておいて面白い動画がよく流れてくるチャンネルを共有しておきます.

www.youtube.com

www.youtube.com

www.youtube.com

良いお年を!

2020年に買ってよかったもの(日用品編・ガジェット編)

今年も振り返りしようと思います.今年は2年近くやっていたPodcastを休止し,ブログもなかなか書けませんでした.忙しかったというよりは,仕事で色々と新しい環境をエンジョイしていたので,なかなかprivateでアウトプットする場がなかったですね.

今年は日用品編・ガジェット編をまとめます.技術書編は長くなったので別記事で今日明日に出します.→書きました.

www.jonki.net

www.jonki.net www.jonki.net

日用品

今年はコロナということもあり,基本的に在宅勤務でした.その在宅勤務を支えてくれた品物が多いです.

まずは夏頃ですが,私はホットコーヒーからアイスコーヒーを作っていたので,氷をガツンと入れてストローで気軽に飲めるタンブラーを探していて見つけた商品がこれ.Klean Kanteenのタンブラーです.金属ストローが結構クセになります.蓋をしたまま飲めるのでよいです.

KEENのサンダル.春〜秋の外出するときは基本これ履いていました.フィット感が素晴らしいのでスニーカーのように履けるサンダルです.KEENのサンダルはマジでいいよ.

次は冬場に入って買った保温ポットです.1Lのポットなので卓上においても邪魔じゃないサイズで絶妙です.継ぎ足し継ぎ足しでお茶,白湯,カモミールティー飲んでます.

最後はパタゴニアのレフュジオです.最高です.サイズがそこそこあって,PCやガジェットの収納ができて,軽くて,しっかりしてる,という製品を探した先にたどり着いた商品です.アークテリクスやノースフェイスもちょっと浮気しましたが,どちらも重かったり,使い勝手悪かったりするんですよね. 仕事も旅行もコイツです.

ガジェット編

Thinkpadのキーボード使ってましたが,久しぶりにHHIKBに手を出しました.打ち心地は最高ですが,キーがやはり少なすぎるので,配列はThinkpadキーボードの方が好みです.

接続マシンを切り替えられるKVMです.在宅勤務ということもあり,日中は会社のWindows PC,仕事後は家のMacに切り替え,というのが面倒すぎて購入.ショートカットキーで切り替えられて便利です.

ラジオです.昔からラジオが結構好き(interfmしか聞かないけど)なのでつい購入.アナログチューンで音もしっかり出るので楽しいです.これも卓上に置いています.

次はサイレントギター.あまり外にもでなくなったので,大学以来,久しぶりにギターでも弾きたいなと思い,突発的に購入.スチール弦も良かったけど,クラシックギターの指弾きが好きだったのでナイロン弦のモデルを購入.これだけでも良いのですが,ヤマハさんがエレガット用の卓上アンプを出しててこれもウッカリ購入.最高か.

Apple Watchの信者と成り果てたので購入.軽くて数回は余裕で充電できるので,旅行時などに重宝しました.

また無駄にヘッドフォンを増やしました.用途に合わせて色々使っているんですが,音楽をしっかり聞くというよりは最近はかけ流しのときが多いので,開放型以外のヘッドフォンはAirpods Proを除いてなかなか使わなくなりました.MDR-MA900というソニーの開放型で軽いヘッドフォンがあるのに買ったのはご愛嬌で(次世代機出ないかなぁ).見た目はそれなりにごっついですが,軽くてつけ心地はよいので冬にピッタリのヘッドフォンです.ATHシリーズは大学時代に始めて自分で買ったヘッドフォン以来,2つ目ですが相変わらず良いですね.

audio-technica エアーダイナミック オープン型ヘッドホン ATH-AD900X

audio-technica エアーダイナミック オープン型ヘッドホン ATH-AD900X

  • 発売日: 2020/08/04
  • メディア: エレクトロニクス

今年はそこまで大物は買わなかったですかね. このリストの中から1つでも響くものがあれば幸いです.それでは.

pudbをもっともっと活用する

以前,pudbの記事を書きましたが,あの時より更に使うようになっていたので,更に色々と便利な機能を紹介したいと思います. pudb?という方は下記の記事をまず御覧ください.今回は発展編ですが,前回の記事の続きというだけで,別に難しいことはありません. www.jonki.net

移動編

  • V, S, B, CCtrl-X での移動

基本的な移動操作になります.Variables, Stack, Breakpoints, Codeのフィールドを自由に移動できます.Ctrl-xでちょっとしたコードを動かしたり確認したりします.各フィールドでの上下移動はj/kが使えます.

  • H: 現在の実行ラインまで戻る

デバッグ中に色々移動したあとに,現在の実行ラインまで戻ります.

  • u/d: スタックの上下移動

スタックトレースを行き来できます.ライブラリの実行などで層が深くなった場合でも,簡単にスタックレイヤー層を上下に移動できます.Stackフィールドにおいて,直接行きたいレイヤー層を選択することが可能です.

実行編

  • t: カーソルのある行まで実行

ブレークポイントを貼って実行 bcでもいけますが,無駄にブレークポイントが増えて不便です.これを覚えておくと地味に便利です.

  • f : 現在の関数の最後まで行く

関数最後までサクッと移動できます.tなどでも代用できますが,「このどうでも良い関数はいいから次!」っていうときに便利です.

Variablesフィールド

  • n: 変数ウォッチ

pudbでも変数のウォッチは可能です!特定条件でのデバッグなどにとても強力です.下記ではミニバッチのインデックスに対してウォッチ変数を作ってみました.

  • custom stringifier設定

Variablesフィールドの表示方法を変更できます.例えば無理やりPytorchのTensorのshapeを表示するようにするとしましょう.~/.config/pudb/custom_stringifier.pyを用意します.

def pudb_stringifier(obj):
    try:
        return (list(obj.shape), 'tensor')
    except:
        return type(obj)

Ctrl-pでVariablesフィールドの設定で,Customのところに先程のファイルを指定します.

そうすることで,Variablesフィールドの各変数を選択した状態でcを押すと,先程のコードが走り,PytorchのTensor変数のshapeが表示されるようになりました.tで型表示,rで中の値を展開という感じにいろいろな表示を切り替えられます.ちなみに先程の設定では,デフォルトはtypeにチェックが入っていますが,Customのところにチェックを入れれば,自動でcキーを押した状態となり,PytorchのTensorであれば,勝手にshapeが表示されます.

Breakpointフィールド

先程のウォッチ変数と似ていますが,ブレークポイントの発動条件を設定できます.コードがある程度進まないと発生しないバグなどの調査時などに便利です.ちなみにdブレークポイントは消せます.

カスタムテーマ

実はカラーテーマも設定できます.方法としては,先程のVariablesフィールドの設定と同様,テーマ設定ファイルを用意して,Ctrl-pで設定画面に行き,ThemeのCustomのところに,そのテーマファイルを指定すればよいです.

公式がテーマのサンプルを公開しています.paletteという辞書型の変数を更新すればよいだけです.フルスクラッチだと辛いので,公式が用意しているテーマからコピペし,気に入らないところだけを書き換えるのが楽です.指定できるカラーは,基本的にXterm colorの256色からです.

例えばブレークポイントのマークとブレークポイント箇所ラインの背景色を変えてみるとこんな感じに変わります.

まとめ

前回の基本的なコマンドに対して,今回はかなり実践的な機能を多く紹介してみました.ショートカットはデフォルトで覚えやすいようになっているので,すぐに覚えられると思います.カスタム機能なども色々いじれて面白いですね.

Enjoy coding!

EMアルゴリズムの勉強メモ

もう何度となく勉強しているであろうEMアルゴリズム,いい加減忘れっぽい正確なので勉強ノートを取った. EMアルゴリズムがどのようなものか,北先生の確率的言語モデルの教科書を使ってノートを取っています.この本はとても良いので,この本を読んで頂くのが一番早いのですが,数式展開をちょっと丁寧にしつつ,自分用なので私の理解も添えて書いています. ちなみに高村先生の本も混合ガウス分布の流れから説明しており,これもとてもわかり易いです.

github.com

言語と計算 (4) 確率的言語モデル

言語と計算 (4) 確率的言語モデル

Seq2seqモデルのBeam Search Decoding (Pytorch)

この記事では,Pytorchで作ったseq2seq型の翻訳モデルを使って,ビームサーチによるデコーディングをします. OpenNMTfairseqを使えば簡単に利用できるのですが,ビームサーチのためだけにこのようなフレームワークを使うのはちょっとなぁ,ということと,ビームサーチ自体はDNNに限らず様々な場面で役に立つ手法なので,この際ピュアに実装してみた,というのがこの記事です.

ちなみに一般的なseq2seqのデコードは,各タイムステップで予測したtop-1の単語を,次ステップのデコーダーの入力に使います. ビームサーチでは,このようなgreedyな条件を緩め,上位K個の予測を使って,デコードしていきます.ビームサーチをよく知らんという方は,Andrew Ngの神説明が参考になると思います. C5W3L08 Attention Model, Andrew Ng.

できたもの

  • seq2seq (rnn) w/wo attentionの翻訳器のビームサーチによるデコーダ

バッチを分解して1文ずつデコードするbeam_search_decodingと(できるだけ)バッチ処理するbatch_beam_search_decodingがあります.出力は同じなので基本的に高速な後者を使うべきですが,理解のため最初の関数を書きました.

※今回始めてビームサーチを書いたのでまだまだイケてないところや変な箇所が多いと思うのでこのリポジトリは修正していく予定です.この記事では,コミット4cb1187760f3a9e7をベースに記事を書きます.

※バグや誤っているところ等,少なからずあると思うのでバグレポートしていただけると嬉しいです.

動かし方はGithubのREADMEを見てください.ビームサーチは推論時(テスト時)に動作するようになっています. github.com

解説

ここからは興味ある方へ,実装したビームサーチ部分を解説します.もっと良い方法あるかなぁと思いつつ,今回はこの方法で書きました.

beam_search_decoding (1文ずつデコード)

ビームサーチの実装は,下記の方の実装を参考にしました.OpenNMTの実装などもあるのですが,まずはシンプルな方法で実装している方を参考にしました. この実装は,1文ずつデコードしているため実装がシンプルです(遅いけど).また各ステップでデコードするときに,どの出力をデコードするか,を選ぶ必要があるのですが,この実装ではヒープ(heapq)を利用しています. 各デコードの結果はBeamSearchNodeという構造体にしているのですが,その状態に至ったときのスコア(対数確率の和)を負にしたものとあわせてヒープに登録しておけば,このヒープから取り出されるノードは,現状で最もスコアが高いノードになります.なのでデコード時は,このヒープからpopしていくだけで良いので高速(O(1))で楽ですね.一方でどんどんヒープにノードが追加されるのに対して,枝刈りによるメモリ解放は行わないため,メモリ効率は良くないです.

ちなみに上位K個(K: ビームサイズ)の結果の取得は,pytorchであればtorch.topk関数で簡単に,値とその引数を取得できるので便利です.

github.com

def beam_search_decoding(decoder,
                         enc_outs,
                         enc_last_h,
                         beam_width,
                         n_best,
                         sos_token,
                         eos_token,
                         max_dec_steps,
                         device):
    """Beam Seach Decoding for RNN
    Args:
        decoder: An RNN decoder model
        enc_outs: A sequence of encoded input. (T, bs, 2H). 2H for bidirectional
        enc_last_h: (bs, H)
        beam_width: Beam search width
        n_best: The number of output sequences for each input
    Returns:
        n_best_list: Decoded N-best results. (bs, T)
    """

    assert beam_width >= n_best

    n_best_list = []
    bs = enc_outs.shape[1]

    # Decoding goes sentence by sentence.
    # So this process is very slow compared to batch decoding process.
    for batch_id in range(bs):
        # Get last encoder hidden state
        decoder_hidden = enc_last_h[batch_id] # (H)
        enc_out = enc_outs[:, batch_id].unsqueeze(1) # (T, 1, 2H)

        # Prepare first token for decoder
        decoder_input = torch.tensor([sos_token]).long().to(device) # (1)

        # Number of sentence to generate
        end_nodes = []

        # starting node
        node = BeamSearchNode(h=decoder_hidden, prev_node=None, wid=decoder_input, logp=0, length=1)

        # whole beam search node graph
        nodes = []

        # Start the queue
        heappush(nodes, (-node.eval(), id(node), node))
        n_dec_steps = 0

        # Start beam search
        while True:
            # Give up when decoding takes too long
            if n_dec_steps > max_dec_steps:
                break

            # Fetch the best node
            score, _, n = heappop(nodes)
            decoder_input = n.wid
            decoder_hidden = n.h

            if n.wid.item() == eos_token and n.prev_node is not None:
                end_nodes.append((score, id(n), n))
                # If we reached maximum # of sentences required
                if len(end_nodes) >= n_best:
                    break
                else:
                    continue

            # Decode for one step using decoder
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden.unsqueeze(0), enc_out)

            # Get top-k from this decoded result
            topk_log_prob, topk_indexes = torch.topk(decoder_output, beam_width) # (1, bw), (1, bw)
            # Then, register new top-k nodes
            for new_k in range(beam_width):
                decoded_t = topk_indexes[0][new_k].view(1) # (1)
                logp = topk_log_prob[0][new_k].item() # float log probability val

                node = BeamSearchNode(h=decoder_hidden.squeeze(0),
                                      prev_node=n,
                                      wid=decoded_t,
                                      logp=n.logp+logp,
                                      length=n.length+1)
                heappush(nodes, (-node.eval(), id(node), node))
            n_dec_steps += beam_width

        # if there are no end_nodes, retrieve best nodes (they are probably truncated)
        if len(end_nodes) == 0:
            end_nodes = [heappop(nodes) for _ in range(beam_width)]

        # Construct sequences from end_nodes
        n_best_seq_list = []
        for score, _id, n in sorted(end_nodes, key=lambda x: x[0]):
            sequence = [n.wid.item()]
            # back trace from end node
            while n.prev_node is not None:
                n = n.prev_node
                sequence.append(n.wid.item())
            sequence = sequence[::-1] # reverse

            n_best_seq_list.append(sequence)

        n_best_list.append(n_best_seq_list)

    return n_best_list

batch_beam_search_decoding (バッチデコード)

次に上記コードをGPUの恩恵を受けれるように,バッチ状態のまま解けるようにしたいと思います.ただここの方法は自明でなく,正直良くわからなかったです. というのも今回の実装では,各文でのビームサーチによる探索グラフ(ヒープ)は異なってきますし,探索終了タイミングもそれぞれ異なってきます. そのため私のコードでは,バッチ化できたのはRNN デコーダーに実際に投げるところです.このRNNデコーダーに投げるデータを作るため,バッチサイズ分のforループを回し,バッチデータを作っているのでこの部分は結局遅いです.また,デコードステップを回していく上で,探索を終了した文が発生しますが,データのshapeは維持されていたほうが使いやすいため,探索が終了した事例もひたすら再利用されているのが残念なところ.

def batch_beam_search_decoding(decoder,
                               enc_outs,
                               enc_last_h,
                               beam_width,
                               n_best,
                               sos_token,
                               eos_token,
                               max_dec_steps,
                               device):
    """Batch Beam Seach Decoding for RNN
    Args:
        decoder: An RNN decoder model
        enc_outs: A sequence of encoded input. (T, bs, 2H). 2H for bidirectional
        enc_last_h: (bs, H)
        beam_width: Beam search width
        n_best: The number of output sequences for each input
    Returns:
        n_best_list: Decoded N-best results. (bs, T)
    """

    assert beam_width >= n_best

    n_best_list = []
    bs = enc_last_h.shape[0]

    # Get last encoder hidden state
    decoder_hidden = enc_last_h # (bs, H)

    # Prepare first token for decoder
    decoder_input = torch.tensor([sos_token]).repeat(1, bs).long().to(device) # (1, bs)

    # Number of sentence to generate
    end_nodes_list = [[] for _ in range(bs)]

    # whole beam search node graph
    nodes = [[] for _ in range(bs)]

    # Start the queue
    for bid in range(bs):
        # starting node
        node = BeamSearchNode(h=decoder_hidden[bid], prev_node=None, wid=decoder_input[:, bid], logp=0, length=1)
        heappush(nodes[bid], (-node.eval(), id(node), node))

    # Start beam search
    fin_nodes = set()
    history = [None for _ in range(bs)]
    n_dec_steps_list = [0 for _ in range(bs)]
    while len(fin_nodes) < bs:
        # Fetch the best node
        decoder_input, decoder_hidden = [], []
        for bid in range(bs):
            if bid not in fin_nodes and n_dec_steps_list[bid] > max_dec_steps:
                fin_nodes.add(bid)

            if bid in fin_nodes:
                score, n = history[bid] # dummy for data consistency
            else:
                score, _, n = heappop(nodes[bid])
                if n.wid.item() == eos_token and n.prev_node is not None:
                    end_nodes_list[bid].append((score, id(n), n))
                    # If we reached maximum # of sentences required
                    if len(end_nodes_list[bid]) >= n_best:
                        fin_nodes.add(bid)
                history[bid] = (score, n)
            decoder_input.append(n.wid)
            decoder_hidden.append(n.h)

        decoder_input = torch.cat(decoder_input).to(device) # (bs)
        decoder_hidden = torch.stack(decoder_hidden, 0).to(device) # (bs, H)

        # Decode for one step using decoder
        decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden, enc_outs) # (bs, V), (bs, H)

        # Get top-k from this decoded result
        topk_log_prob, topk_indexes = torch.topk(decoder_output, beam_width) # (bs, bw), (bs, bw)
        # Then, register new top-k nodes
        for bid in range(bs):
            if bid in fin_nodes:
                continue
            score, n = history[bid]
            if n.wid.item() == eos_token and n.prev_node is not None:
                continue
            for new_k in range(beam_width):
                decoded_t = topk_indexes[bid][new_k].view(1) # (1)
                logp = topk_log_prob[bid][new_k].item() # float log probability val

                node = BeamSearchNode(h=decoder_hidden[bid],
                                      prev_node=n,
                                      wid=decoded_t,
                                      logp=n.logp+logp,
                                      length=n.length+1)
                heappush(nodes[bid], (-node.eval(), id(node), node))
            n_dec_steps_list[bid] += beam_width

    # Construct sequences from end_nodes
    # if there are no end_nodes, retrieve best nodes (they are probably truncated)
    for bid in range(bs):
        if len(end_nodes_list[bid]) == 0:
            end_nodes_list[bid] = [heappop(nodes[bid]) for _ in range(beam_width)]

        n_best_seq_list = []
        for score, _id, n in sorted(end_nodes_list[bid], key=lambda x: x[0]):
            sequence = [n.wid.item()]
            while n.prev_node is not None:
                n = n.prev_node
                sequence.append(n.wid.item())
            sequence = sequence[::-1] # reverse

            n_best_seq_list.append(sequence)

        n_best_list.append(copy.copy(n_best_seq_list))

    return n_best_list

速度比較

バッチ版のビームサーチは実装がかなり汚いですが,実行速度はどの程度恩恵があるか調べてみました.バッチ数は128でビーム幅は10です. 非バッチ版の実装と比べて,翻訳結果は変わらないが,速度としては2,3倍早くなっているのがわかります(バッチサイズとビームサイズで変わりますが). それでもまだ結構遅いなという印象ですが,まぁまぁ最初はこんなものでしょうと自分に言い訳します.

% python run.py --attention --skip_train --model_path ./ckpts/s2s-attn.pt
Number of training examples: 29000
Number of validation examples: 1014
Number of testing examples: 1000
Unique tokens in source (de) vocabulary: 7855
Unique tokens in target (en) vocabulary: 5893

In: <SOS> . schnee den über laufen hunde mittelgroße zwei <EOS>
for loop beam search time: 8.718
Out: Rank-1: <SOS> two medium brown dogs run across the snow . the snow . <EOS>
Out: Rank-2: <SOS> two medium brown dogs run across the snow . <EOS>
Out: Rank-3: <SOS> two medium brown dogs run across the snow . the snow . . <EOS>
Out: Rank-4: <SOS> two medium brown dogs run across the snow . . <EOS>
Out: Rank-5: <SOS> two medium brown dogs run across the snow . snow . <EOS>
Batch beam search time: 2.994
Out: Rank-1: <SOS> two medium brown dogs run across the snow . the snow . <EOS>
Out: Rank-2: <SOS> two medium brown dogs run across the snow . <EOS>
Out: Rank-3: <SOS> two medium brown dogs run across the snow . the snow . . <EOS>
Out: Rank-4: <SOS> two medium brown dogs run across the snow . . <EOS>
Out: Rank-5: <SOS> two medium brown dogs run across the snow . snow . <EOS>

In: <SOS> . <unk> mit tüten gehsteig einem auf verkauft frau eine <EOS>
for loop beam search time: 9.654
Out: Rank-1: <SOS> a woman is selling on her
Out: Rank-2: <SOS> a woman woman selling a
Out: Rank-3: <SOS> a woman is her selling
Out: Rank-4: <SOS> a woman is selling vegetables on a sidewalk
Out: Rank-5: <SOS> a woman woman selling rice
Out: Rank-6: <SOS> a woman is selling her on
Out: Rank-7: <SOS> a woman is selling watermelon on a
Out: Rank-8: <SOS> a woman is selling on the
Out: Rank-9: <SOS> a woman is sells selling
Out: Rank-10: <SOS> a woman selling selling on
Batch beam search time: 3.256
Out: Rank-1: <SOS> a woman is selling on her
Out: Rank-2: <SOS> a woman woman selling a
Out: Rank-3: <SOS> a woman is her selling
Out: Rank-4: <SOS> a woman is selling vegetables on a sidewalk
Out: Rank-5: <SOS> a woman woman selling rice
Out: Rank-6: <SOS> a woman is selling her on
Out: Rank-7: <SOS> a woman is selling watermelon on a
Out: Rank-8: <SOS> a woman is selling on the
Out: Rank-9: <SOS> a woman is sells selling
Out: Rank-10: <SOS> a woman selling selling on

まとめ

今回ビームサーチを生実装してみました.ヒープによる実装で楽をしましたが,何となく動作イメージは掴めてきました.スコアの正規化や効率的なバッチデコーディングにより改善点はまだまだありそうですが,今回はこんなところで.詳しい方はアドバイス求む!