アップサンプリングなどで使われるConvTransposeについて,イマイチ理解できていなかったように思えるので可視化した. Convolutionといえば,conv_arithmeticのGIFが分かりやすいが,ConvTransposeに関しては,通常のConvolutionのように見え,strideやpaddingが通常のConv時とどう違うのか分かりにくいように感じたので,How PyTorch Transposed Convs1D Workを参考に書いてみた.
可視化
これがその図である.Convを逆順にたどるとConvTranposeになっているのがわかると思う. ConvTranpose時のpadding
がやや分かりにくいかもしれないが,どんだけはみ出して始めるか?(はみ出た分は後で捨てる),という考えで見ると分かりやすいかもしれない.
コード
この図をpytorchで書くとこうなる.
import torch import torch.nn as nn x = torch.ones(1,1,6) enc = nn.Conv1d(1,1,kernel_size=3,padding=2,stride=3,bias=False) enc.weight.data = torch.ones(1, 1, 3) dec = nn.ConvTranspose1d(1,1,kernel_size=3,padding=1,stride=2,bias=False) dec.weight.data = torch.ones(1, 1, 3) enc(x) h = enc(x) # shape: (1, 1, 3), data: [[[1,3,2]]] print('h:', h) out = dec(h) # shape: (1, 1, 5), data: [[[1,4,3,5,2]]] print('out:', out)
参考
// ちなみに畳み込み全然初心者なので嘘ついてたら教えて下さい.