The jonki

呼ばれて飛び出てじょじょじょじょーんき

torch.nn.ConvTransposeの可視化

アップサンプリングなどで使われるConvTransposeについて,イマイチ理解できていなかったように思えるので可視化した. Convolutionといえば,conv_arithmeticのGIFが分かりやすいが,ConvTransposeに関しては,通常のConvolutionのように見え,strideやpaddingが通常のConv時とどう違うのか分かりにくいように感じたので,How PyTorch Transposed Convs1D Workを参考に書いてみた.

可視化

これがその図である.Convを逆順にたどるとConvTranposeになっているのがわかると思う. ConvTranpose時のpaddingがやや分かりにくいかもしれないが,どんだけはみ出して始めるか?(はみ出た分は後で捨てる),という考えで見ると分かりやすいかもしれない.

f:id:jonki:20210104211215p:plain

  • 純化のためバッチサイズ及びチャンネル数は1,畳み込み時のチャンネル出力も1に固定.
  • データの流れがわかりやすくなるので,入力データ及びカーネルの値はすべて1に固定.
  • 今回1dだが,2dでも同様.

コード

この図をpytorchで書くとこうなる.

import torch
import torch.nn as nn

x = torch.ones(1,1,6)

enc = nn.Conv1d(1,1,kernel_size=3,padding=2,stride=3,bias=False)
enc.weight.data = torch.ones(1, 1, 3)
dec = nn.ConvTranspose1d(1,1,kernel_size=3,padding=1,stride=2,bias=False)
dec.weight.data = torch.ones(1, 1, 3)

enc(x)
h = enc(x) # shape: (1, 1, 3), data: [[[1,3,2]]]
print('h:', h)
out = dec(h) # shape: (1, 1, 5), data: [[[1,4,3,5,2]]]
print('out:', out)

参考

medium.com

// ちなみに畳み込み全然初心者なので嘘ついてたら教えて下さい.