meowの覚え書き

write to think, create to understand

PyTorchのnn.ConvTranspose2dに与えるパラメータは畳み込みから逆算して考える

f:id:meow_memow:20201229144751j:plain

(畳み込みの画像はこちらのもの)

この記事では、転置畳み込み層のPyTorch実装であるnn.ConvTranspose2dの出力サイズを自分が狙った通りに生成できるように、パラメータを与える知見を共有する。

画像を生成するDNNモデルにおいてアップサンプリングは不可欠な要素である。アップサンプリングを行う時は転置畳み込み層が用いられることが多い。PyTorchで2次元の転置畳み込み層を扱いたい時はnn.ConvTranspose2dを使う。
しかし、いざnn.ConvTranspose2dを使って自分で画像生成のモデリングをしようとした時、公式ドキュメントを見ても引数の値に何を設定すればわからなかった。正確に言うと、動かして処理の内容を推測しようとしたが、挙動が意味不明だった。例えば出力の特徴マップの幅を増やしたいのでpaddingを増やすと逆に幅が小さくなるなどである。
そこで、調べ物をして、どうパラメータを設定すれば想定サイズの特徴マップを生成できるか調査した。

結論としては「畳み込みの逆問題を解く」なのだが、そこに至るまでの過程を省略して理解するのはおそらく難しいので、下記で導入を交えながら説明をする。


スポンサードリンク

導入: 畳み込み層における逆伝播

調査の結果、パラメータを想定通りに設定するには、畳み込み層の逆伝播で何が行われているかの理解が必要であるという結論に至った。 その根拠としては、参考文献[2]によると、TensorFlowに関して、畳み込み層の逆伝播で転置畳み込み層が使われていることを述べている。PyTorchで同様の実装が行われているかはわからなかったが、転置畳み込み層クラスの引数の意味は共通している。

「畳み込み層の逆伝播で転置畳み込みが行われる」とはどういうことか、例を交えて説明する。

畳み込み層の順伝播

例えば下図のようなCNNの一部があったとして、ある畳み込み層Conv1に着目する。このConv1はConv0の特徴マップを入力として受け取る。また、Conv1の出力の特徴マップはConv2へ渡すとする。

f:id:meow_memow:20201229145612j:plain

順伝播ではご存知の通り、特徴マップXカーネルFで畳み込みの計算を行い、新たな特徴マップOを生成するという処理である。 このConv1では、Fが2x2の時、padding=0, stride=1とする。 簡単のためにchannelはinput,outputともに1とし、biasはなく、活性化関数もないとする。Xのサイズが3x3のとき、Oのサイズは2x2となる

f:id:meow_memow:20201229150108j:plain

畳み込み層の逆伝播

f:id:meow_memow:20201229150345j:plain

次に、この畳み込み層の逆伝播を考える。
逆伝播での勾配計算というと、重み(畳み込み層ではカーネルとバイアス)に対する勾配をまずイメージするかもしれない。 しかし、勾配は重みに対してだけでなく、入力の特徴マップに対しても求める必要がある。何故ならば、Conv1にとっての畳み込み層の入力の特徴マップの勾配\frac{\partial  L}{\partial X}は、1つ手前の畳み込み層Conv0にとっての局所的な誤差として逆伝播するからである。 この、入力の特徴マップの勾配\frac{\partial  L}{\partial X}の求め方を簡単に説明する。詳しくは文献[1],[2]を参照いただきたい。

上図のように、逆伝播時はConv2層から誤差\frac{\partial  L}{\partial  O}が入ってくる。

ここで、Xの最小要素であるピクセル1X _ iが誤差Lに及ぼした影響\frac{\partial  L}{\partial X _ i}は、多変数関数の連鎖律[5]より、

\frac{\partial  L}{\partial X _ i} = \sum_{k} \frac{\partial  L}{\partial  O _ k} \cdot \frac{\partial  O _ k}{\partial  X _ i}

と計算できる。

この内、右辺の\frac{\partial  L}{\partial  O _ k} は、Conv2からの逆伝播した勾配の各ピクセルである。
また、右辺の\frac{\partial O _ k}{\partial X _ i}は、O _ kの式をX _ i偏微分したものである。
O _ kは次の式で表される。

したがって、これらを例えばX _ 1偏微分し、元の式に代入すると、

[tex: \begin{align} \frac{\partial L}{\partial X _ 1} &= \sum_{k} \frac{\partial L}{\partial O _ k} \cdot \frac{\partial O _ k}{\partial X _ 1} \ &= F _ 1 \cdot O _ 1 + 0 \cdot O _ 2 + 0 \cdot O _ 3 + 0 \cdot O _ 4 \ &= F _ 1 O _ 1 \end{align}

これを\frac{\partial  L}{\partial X _ 9} まで 求めて、shapeを整えたものが、\frac{\partial  L}{\partial X}である。

畳み込み計算で表現

この勾配\frac{\partial  L}{\partial X _ i}を求める過程は、実は畳み込み計算で表現できる。

f:id:meow_memow:20201229153537j:plain

\frac{\partial O}{\partial X _ i}が特定のフィルタor0であり、これが\frac{\partial  L}{\partial  O}の各ピクセルと掛け合わされることから大まかに把握できるかと思う。微分して0になる場所は、入力\frac{\partial  L}{\partial  O}に0をpaddingすることで対応する。
畳み込み計算と1つ違う点としては、フィルタを180°回転させることである(実際の行列計算ではフィルタの行列を転置する操作にあたる[4])。

畳み込みをしながら\frac{\partial  L}{\partial X _ i}を求めていくイメージを下図に示す

\frac{\partial  L}{\partial X _ 1}は、

\frac{\partial  L}{\partial X _ 2}は、

最後までフィルターをスライドすると、\frac{\partial  L}{\partial X _ 9} まで求まる。

このように、入力の特徴マップに対する勾配\frac{\partial  L}{\partial X}を求める上で、転置畳み込みが使用される。

パラメータの謎の解明

ここで冒頭、nn.ConvTranspose2dのパラメータ名と処理の整合性が合っていないと思った話に戻る。

実はnn.ConvTranspose2dの引数のpadding,strideは、Convの順伝播のパラメータを与えなければいけないのである。
この理由は、TransConvの入力としてConvの出力の勾配(例だと\frac{\partial  L}{\partial  O})がくることを想定しているためだと考えられる。
もともとは勾配を求める用途のものを、(Convの文脈とは独立して)単なるアップサンプリング用途に使おうとしているから私は挙動を理解できなかったのである。

パラメータをどう指定するのか

seq2seq型のモデルだったらencoderのConv層のパラメータがあると思うので、decoder側のnn.ConvTranspose2dのパラメータにはそれを与えてやればよい。
しかし、seq2seqモデルではない場合などで、アップサンプリングにnn.ConvTranspose2dを使いたい場合、出力の特徴マップを所要のサイズにするにはパラメータをどう与えるか。

公式ドキュメント[3]によると、サイズは関して下記の計算式で求まると書かれている2
H _ {out} = (H _ {in} −1)×stride[0 −2×padding[0]+dilation[0]×(kernel_size[0]−1)+output_padding[0]+1]

しかし、これを覚えるのは面倒である。

したがって、このH_outの式を覚えるよりも、畳み込みの逆問題を解いた方が早いと私は考える。
例えば、3x3の特徴マップを5x5にアップサンプリングしたい場合、逆問題として5x5の特徴マップを畳み込みで3x3にするにはどうすればいいかを考える。
答えの一例としては、カーネルを2x2、stride=2, padding=0にすればよいので、これをnn.ConvTranspose2dのパラメータに与えるといった具合である。

まとめ

nn.ConvTranspose2dで所要の出力サイズにするためのパラメータ設定方法に関して述べた。

調べ物をする中で、引数と挙動の謎が解けた。文献[1][2]にかかれている通り、畳み込みの逆伝播さえつかめれば転置畳み込みは理解できることがわかった。

逆に考えると、転置畳み込みを理解していないということは、畳み込みも理解していない、ということになる。

参考文献

ふろく: デバッグ用のコード

一応、nn.ConvTranspose2dをかけた時のshapeを確認するためのコードを用意した。


  1. 特徴マップの最小要素を正式には何と呼ぶのかわからなかった。間違っているかもしれないが、本稿ではピクセルと呼ぶことにした。

  2. これはHeightの例であるが、Widthも同様の計算が行われる。