この記事では,Pytorchで作ったseq2seq型の翻訳モデルを使って,ビームサーチによるデコーディングをします. OpenNMTやfairseqを使えば簡単に利用できるのですが,ビームサーチのためだけにこのようなフレームワークを使うのはちょっとなぁ,ということと,ビームサーチ自体はDNNに限らず様々な場面で役に立つ手法なので,この際ピュアに実装してみた,というのがこの記事です.
ちなみに一般的なseq2seqのデコードは,各タイムステップで予測したtop-1の単語を,次ステップのデコーダーの入力に使います. ビームサーチでは,このようなgreedyな条件を緩め,上位K個の予測を使って,デコードしていきます.ビームサーチをよく知らんという方は,Andrew Ngの神説明が参考になると思います. C5W3L08 Attention Model, Andrew Ng.
できたもの
- seq2seq (rnn) w/wo attentionの翻訳器のビームサーチによるデコーダー
バッチを分解して1文ずつデコードするbeam_search_decoding
と(できるだけ)バッチ処理するbatch_beam_search_decoding
があります.出力は同じなので基本的に高速な後者を使うべきですが,理解のため最初の関数を書きました.
※今回始めてビームサーチを書いたのでまだまだイケてないところや変な箇所が多いと思うのでこのリポジトリは修正していく予定です.この記事では,コミット4cb1187760f3a9e7
をベースに記事を書きます.
※バグや誤っているところ等,少なからずあると思うのでバグレポートしていただけると嬉しいです.
動かし方はGithubのREADMEを見てください.ビームサーチは推論時(テスト時)に動作するようになっています. github.com
解説
ここからは興味ある方へ,実装したビームサーチ部分を解説します.もっと良い方法あるかなぁと思いつつ,今回はこの方法で書きました.
beam_search_decoding (1文ずつデコード)
ビームサーチの実装は,下記の方の実装を参考にしました.OpenNMTの実装などもあるのですが,まずはシンプルな方法で実装している方を参考にしました.
この実装は,1文ずつデコードしているため実装がシンプルです(遅いけど).また各ステップでデコードするときに,どの出力をデコードするか,を選ぶ必要があるのですが,この実装ではヒープ(heapq)を利用しています.
各デコードの結果はBeamSearchNode
という構造体にしているのですが,その状態に至ったときのスコア(対数確率の和)を負にしたものとあわせてヒープに登録しておけば,このヒープから取り出されるノードは,現状で最もスコアが高いノードになります.なのでデコード時は,このヒープからpopしていくだけで良いので高速(O(1))で楽ですね.一方でどんどんヒープにノードが追加されるのに対して,枝刈りによるメモリ解放は行わないため,メモリ効率は良くないです.
ちなみに上位K個(K: ビームサイズ)の結果の取得は,pytorchであればtorch.topk
関数で簡単に,値とその引数を取得できるので便利です.
def beam_search_decoding(decoder, enc_outs, enc_last_h, beam_width, n_best, sos_token, eos_token, max_dec_steps, device): """Beam Seach Decoding for RNN Args: decoder: An RNN decoder model enc_outs: A sequence of encoded input. (T, bs, 2H). 2H for bidirectional enc_last_h: (bs, H) beam_width: Beam search width n_best: The number of output sequences for each input Returns: n_best_list: Decoded N-best results. (bs, T) """ assert beam_width >= n_best n_best_list = [] bs = enc_outs.shape[1] # Decoding goes sentence by sentence. # So this process is very slow compared to batch decoding process. for batch_id in range(bs): # Get last encoder hidden state decoder_hidden = enc_last_h[batch_id] # (H) enc_out = enc_outs[:, batch_id].unsqueeze(1) # (T, 1, 2H) # Prepare first token for decoder decoder_input = torch.tensor([sos_token]).long().to(device) # (1) # Number of sentence to generate end_nodes = [] # starting node node = BeamSearchNode(h=decoder_hidden, prev_node=None, wid=decoder_input, logp=0, length=1) # whole beam search node graph nodes = [] # Start the queue heappush(nodes, (-node.eval(), id(node), node)) n_dec_steps = 0 # Start beam search while True: # Give up when decoding takes too long if n_dec_steps > max_dec_steps: break # Fetch the best node score, _, n = heappop(nodes) decoder_input = n.wid decoder_hidden = n.h if n.wid.item() == eos_token and n.prev_node is not None: end_nodes.append((score, id(n), n)) # If we reached maximum # of sentences required if len(end_nodes) >= n_best: break else: continue # Decode for one step using decoder decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden.unsqueeze(0), enc_out) # Get top-k from this decoded result topk_log_prob, topk_indexes = torch.topk(decoder_output, beam_width) # (1, bw), (1, bw) # Then, register new top-k nodes for new_k in range(beam_width): decoded_t = topk_indexes[0][new_k].view(1) # (1) logp = topk_log_prob[0][new_k].item() # float log probability val node = BeamSearchNode(h=decoder_hidden.squeeze(0), prev_node=n, wid=decoded_t, logp=n.logp+logp, length=n.length+1) heappush(nodes, (-node.eval(), id(node), node)) n_dec_steps += beam_width # if there are no end_nodes, retrieve best nodes (they are probably truncated) if len(end_nodes) == 0: end_nodes = [heappop(nodes) for _ in range(beam_width)] # Construct sequences from end_nodes n_best_seq_list = [] for score, _id, n in sorted(end_nodes, key=lambda x: x[0]): sequence = [n.wid.item()] # back trace from end node while n.prev_node is not None: n = n.prev_node sequence.append(n.wid.item()) sequence = sequence[::-1] # reverse n_best_seq_list.append(sequence) n_best_list.append(n_best_seq_list) return n_best_list
batch_beam_search_decoding (バッチデコード)
次に上記コードをGPUの恩恵を受けれるように,バッチ状態のまま解けるようにしたいと思います.ただここの方法は自明でなく,正直良くわからなかったです. というのも今回の実装では,各文でのビームサーチによる探索グラフ(ヒープ)は異なってきますし,探索終了タイミングもそれぞれ異なってきます. そのため私のコードでは,バッチ化できたのはRNN デコーダーに実際に投げるところです.このRNNデコーダーに投げるデータを作るため,バッチサイズ分のforループを回し,バッチデータを作っているのでこの部分は結局遅いです.また,デコードステップを回していく上で,探索を終了した文が発生しますが,データのshapeは維持されていたほうが使いやすいため,探索が終了した事例もひたすら再利用されているのが残念なところ.
def batch_beam_search_decoding(decoder, enc_outs, enc_last_h, beam_width, n_best, sos_token, eos_token, max_dec_steps, device): """Batch Beam Seach Decoding for RNN Args: decoder: An RNN decoder model enc_outs: A sequence of encoded input. (T, bs, 2H). 2H for bidirectional enc_last_h: (bs, H) beam_width: Beam search width n_best: The number of output sequences for each input Returns: n_best_list: Decoded N-best results. (bs, T) """ assert beam_width >= n_best n_best_list = [] bs = enc_last_h.shape[0] # Get last encoder hidden state decoder_hidden = enc_last_h # (bs, H) # Prepare first token for decoder decoder_input = torch.tensor([sos_token]).repeat(1, bs).long().to(device) # (1, bs) # Number of sentence to generate end_nodes_list = [[] for _ in range(bs)] # whole beam search node graph nodes = [[] for _ in range(bs)] # Start the queue for bid in range(bs): # starting node node = BeamSearchNode(h=decoder_hidden[bid], prev_node=None, wid=decoder_input[:, bid], logp=0, length=1) heappush(nodes[bid], (-node.eval(), id(node), node)) # Start beam search fin_nodes = set() history = [None for _ in range(bs)] n_dec_steps_list = [0 for _ in range(bs)] while len(fin_nodes) < bs: # Fetch the best node decoder_input, decoder_hidden = [], [] for bid in range(bs): if bid not in fin_nodes and n_dec_steps_list[bid] > max_dec_steps: fin_nodes.add(bid) if bid in fin_nodes: score, n = history[bid] # dummy for data consistency else: score, _, n = heappop(nodes[bid]) if n.wid.item() == eos_token and n.prev_node is not None: end_nodes_list[bid].append((score, id(n), n)) # If we reached maximum # of sentences required if len(end_nodes_list[bid]) >= n_best: fin_nodes.add(bid) history[bid] = (score, n) decoder_input.append(n.wid) decoder_hidden.append(n.h) decoder_input = torch.cat(decoder_input).to(device) # (bs) decoder_hidden = torch.stack(decoder_hidden, 0).to(device) # (bs, H) # Decode for one step using decoder decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden, enc_outs) # (bs, V), (bs, H) # Get top-k from this decoded result topk_log_prob, topk_indexes = torch.topk(decoder_output, beam_width) # (bs, bw), (bs, bw) # Then, register new top-k nodes for bid in range(bs): if bid in fin_nodes: continue score, n = history[bid] if n.wid.item() == eos_token and n.prev_node is not None: continue for new_k in range(beam_width): decoded_t = topk_indexes[bid][new_k].view(1) # (1) logp = topk_log_prob[bid][new_k].item() # float log probability val node = BeamSearchNode(h=decoder_hidden[bid], prev_node=n, wid=decoded_t, logp=n.logp+logp, length=n.length+1) heappush(nodes[bid], (-node.eval(), id(node), node)) n_dec_steps_list[bid] += beam_width # Construct sequences from end_nodes # if there are no end_nodes, retrieve best nodes (they are probably truncated) for bid in range(bs): if len(end_nodes_list[bid]) == 0: end_nodes_list[bid] = [heappop(nodes[bid]) for _ in range(beam_width)] n_best_seq_list = [] for score, _id, n in sorted(end_nodes_list[bid], key=lambda x: x[0]): sequence = [n.wid.item()] while n.prev_node is not None: n = n.prev_node sequence.append(n.wid.item()) sequence = sequence[::-1] # reverse n_best_seq_list.append(sequence) n_best_list.append(copy.copy(n_best_seq_list)) return n_best_list
速度比較
バッチ版のビームサーチは実装がかなり汚いですが,実行速度はどの程度恩恵があるか調べてみました.バッチ数は128でビーム幅は10です. 非バッチ版の実装と比べて,翻訳結果は変わらないが,速度としては2,3倍早くなっているのがわかります(バッチサイズとビームサイズで変わりますが). それでもまだ結構遅いなという印象ですが,まぁまぁ最初はこんなものでしょうと自分に言い訳します.
% python run.py --attention --skip_train --model_path ./ckpts/s2s-attn.pt Number of training examples: 29000 Number of validation examples: 1014 Number of testing examples: 1000 Unique tokens in source (de) vocabulary: 7855 Unique tokens in target (en) vocabulary: 5893 In: <SOS> . schnee den über laufen hunde mittelgroße zwei <EOS> for loop beam search time: 8.718 Out: Rank-1: <SOS> two medium brown dogs run across the snow . the snow . <EOS> Out: Rank-2: <SOS> two medium brown dogs run across the snow . <EOS> Out: Rank-3: <SOS> two medium brown dogs run across the snow . the snow . . <EOS> Out: Rank-4: <SOS> two medium brown dogs run across the snow . . <EOS> Out: Rank-5: <SOS> two medium brown dogs run across the snow . snow . <EOS> Batch beam search time: 2.994 Out: Rank-1: <SOS> two medium brown dogs run across the snow . the snow . <EOS> Out: Rank-2: <SOS> two medium brown dogs run across the snow . <EOS> Out: Rank-3: <SOS> two medium brown dogs run across the snow . the snow . . <EOS> Out: Rank-4: <SOS> two medium brown dogs run across the snow . . <EOS> Out: Rank-5: <SOS> two medium brown dogs run across the snow . snow . <EOS> In: <SOS> . <unk> mit tüten gehsteig einem auf verkauft frau eine <EOS> for loop beam search time: 9.654 Out: Rank-1: <SOS> a woman is selling on her Out: Rank-2: <SOS> a woman woman selling a Out: Rank-3: <SOS> a woman is her selling Out: Rank-4: <SOS> a woman is selling vegetables on a sidewalk Out: Rank-5: <SOS> a woman woman selling rice Out: Rank-6: <SOS> a woman is selling her on Out: Rank-7: <SOS> a woman is selling watermelon on a Out: Rank-8: <SOS> a woman is selling on the Out: Rank-9: <SOS> a woman is sells selling Out: Rank-10: <SOS> a woman selling selling on Batch beam search time: 3.256 Out: Rank-1: <SOS> a woman is selling on her Out: Rank-2: <SOS> a woman woman selling a Out: Rank-3: <SOS> a woman is her selling Out: Rank-4: <SOS> a woman is selling vegetables on a sidewalk Out: Rank-5: <SOS> a woman woman selling rice Out: Rank-6: <SOS> a woman is selling her on Out: Rank-7: <SOS> a woman is selling watermelon on a Out: Rank-8: <SOS> a woman is selling on the Out: Rank-9: <SOS> a woman is sells selling Out: Rank-10: <SOS> a woman selling selling on
まとめ
今回ビームサーチを生実装してみました.ヒープによる実装で楽をしましたが,何となく動作イメージは掴めてきました.スコアの正規化や効率的なバッチデコーディングにより改善点はまだまだありそうですが,今回はこんなところで.詳しい方はアドバイス求む!