meowの覚え書き

write to think, create to understand

【Council GAN】Im2Im論文『Breaking the cycle - Colleagues are all you need』 を 顔写真 → アニメ顔変換タスクを中心に理解する

f:id:meow_memow:20200709081635j:plain:w500

2ドメイン間で画像の対応が不要(unpaired)な GAN ベースの画像変換(Image to Image translation; im2im)手法である Council GAN について見ていく。手法の試し方も一応載せておく。 なお、私の関心は顔写真 → アニメ顔変換タスク(selfie2anime)のみなので、その他のタスク(メガネ除去、男性-女性顔変換)は扱わない。


スポンサードリンク

論文に関して

CVPR2020 に採択された論文である。
2019 年 11 月にプレプリント版が上がっていたが、2020年6月ごろになってプロジェクトページおよびリポジトリ(github)が公開された1

提案手法の要旨

f:id:meow_memow:20200709082021j:plain GAN全体系のLossにCouncil lossを追加しているところが新規性のポイントである。もう少し補足すると、Cycle Consistency Loss の代わりにCouncil loss(後述)を使った所が真新しい。

ちなみに、第1項目のGAN_lossはLSGANで使われているlossである。すなわち、識別器側は本物と生成器が生成した偽物を見分けるように学習し、生成器側は識別器を本物だと騙すように学習する時のlossである2

Council Loss とは

f:id:meow_memow:20200709082220j:plain:w500

counsilというユニットを複数用意し、それぞれのユニットが相互作用を与えるようにユニット内で loss を計算する。
このcounsilは1ユニット3つの要素からなる。画像生成器  {G_i} 、画像の真贋の識別器 D_i、そして、画像(真贋問わず)が自分のcounsilから生成されたものか否かを識別する識別器 \hat{D_i}である。なお、添字の  i i番目のcounsilにおける要素であることを表している。

council lossの定義は下記。

f:id:meow_memow:20200709082911j:plain

\hat{D_i} はGAN全体系のloss  Vを 上げたいので、自分が所属する i番目のcouncilの生成器  {G_i} の出力画像に対する事後確率は0で、他のcouncilが生成した画像に対しては1に近い値を出すように学習する。すなわち、自分のcounsilが生成した画像をfake、他のconsilが生成した画像をrealと判定したい。(なんか直感と反するけど)。
一方で生成器 G_i側は\hat{D_i}を騙したい。すなわち、自分の生成した画像は同じcouncilから生成したものと見なされたいし、他のcounsilには違うと見なされたいので、他の( jと置く) counsil の生成器  G_jが持たないような特徴を生成するように学習する。
これが Counsil loss の役割のようだ。

Council loss のよい点

p.1,col.2,l.20〜にそれっぽいことが書いてある。

  • 各 council ごとの \hat{D_i}の集合的な意見を活用することで、 G_iがより安定し多様に富んだドメイン変換ができる
    • 生成される画像は、counsil 間で共通の特徴があるべき
    • \hat{D_i}によって、 G_iは他の counsil が見分けられるような特徴を生成するように収束する。これはソースとターゲットドメイン間の相互情報量の最大化をしている。
      • これにより、生成された画像がソース側の重要な特徴を維持する

事実なのか意見なのかはわからないが、このような狙いがあるらしい。

どういうことか考えてみた。例えば、council が 3 つあるとして、 \hat{D_1}を騙すために、 G_2 G_3は共通の特徴を持ちうるということ。 \hat{D_1}は髪の毛の部分を見て自分の counsil かどうかを判定しているとすれば、G_2 G_3は髪の毛の部分は似たような画像変換を獲得する、でも \hat{D_2}も騙したいのでG_2 G_3は髪の毛以外の所で差別化を図らざるを得なくなり、最終的にそれぞれの counsil は共通の特徴を持ちつつもそれぞれ異なる画像変換を獲得する的なイメージだろうか。

また、 \hat{D_i}には、ソースドメインの画像も直接入力している。この理由はソースドメインの特徴も学んでほしいと願っている(wish)からとのこと(p.4 Fig.4の説明より)。

生成器G_iの補足

生成器側は、Encoder-Decoderアーキテクチャだが、デコードフェーズ時に、ノイズ z_iを混ぜている。このことで、多様な画像を生成している(Fig.3)。

f:id:meow_memow:20200709083956j:plain:h250

先行手法Cycle Consistency Loss の欠点

これまでの、unpaired な image to image translation を GAN ベースで実現する研究は Cycle GAN で提案された Cycle Consistency Loss が使われていた。一方、Council GAN は以上のような仕組みで Cycle Consistency Loss を使わずに im2im できるようになった。論文タイトルの「Breaking the cycle」の由来である。

じゃあ何でCycle Consistency Loss がよくないとしているのか理由も調べた。

主に論文 p.2,col.2,l.26〜らへん

  • Cycle Consistency Loss は、不要な制約をかけうる。
    • cycle して元に戻そうとする余り、ターゲットの画像にソースドメイン側の特徴を残してしまう(p.1,col2,l.12)
    • 変化の量を制限してしまう(p.3,col.1,l.11)
  • 隠れた情報が残ってしまうことを回避できる
    • ある研究3によると、Cycle GANはターゲットドメインへの変換時にソースドメインの成分を高周波成分に埋め込んでいるので、Adversal Attackの脆弱性があるとのこと。これにより入力画像に細工をすることで任意の変換ができてしまうとのこと。
  • (Cycle GAN に関わらず、既存の画像変換手法は変換結果に多様性がない)

[脱線]そもそもCycle Consistency Loss はunpairedなim2imとどう関係があるのか

個人的な興味から論文の本筋とちょっと脱線する。そもそも、どういう理屈があってCycle Consistency Lossでunpairedなim2imが実現できるのか気になったのでCycle GAN元論文を読んで調べてみた。

p.4,"3.2. Cycle Consistency Loss"の節

Adversarial training can, in theory, learn mappings G and F that produce outputs identically distributed as target domains Y and X respectively (strictly speaking, this requires G and F to be stochastic functions) [15]. However, with large enough capacity, a network can map the same set of input images to any random permutation of images in the target domain, where any of the learned mappings can induce an output distribution that matches the target distribution. Thus, adversarial losses alone cannot guarantee that the learned function can map an individual input xi to a desired output yi . To further reduce the space of possible mapping functions, we argue that the learned mapping functions should be cycle-consistent:

読んでも私には理解できる力がなかった。特にrandom permutationのあたり。
とりあえず部分的に理解すると、「Adversarial trainingは、理論上はXとYのドメイン間の写像を学習できるが、ネットワークに変換能力が十分にある場合、adversarial lossesだけでは、変換候補先がたくさんありすぎてXからYへ移す保証ができない。そこで、写像が移しうる候補を制限するために、Cycle Consistency Lossを入れるべきだと我々は考えた」的な感じか。
大事なのは、Unpaired im2imに必ずしもCycle Consistency Loss がある必要はないということ。写像がちゃんと獲得できるならば lossは何でもいい。Council Lossはそれを実証している。

実験で手法の比較

councilの数は4。4つのcouncilの生成器から出力された結果の一例が示されている。

f:id:meow_memow:20200709084558j:plain

これに対する著者の定性的な評価(p.7,col.1,l.20〜)としては、

  • 表情や顔の構造(つまり、顎の形状)が入力画像によく似ている。これは、councilのメンバーが入力画像中の特徴の方が"同意"しやすいためだと考えられる
    • ※ "同意"の意味がわからなかったが、GAN 学習時、生成器はソースドメイン側の特徴を持っていれば識別器を騙しやすいということだろうか。そうだとすると当たり前のことを言っている気がする。
  • selfie2anime は難しいタスクである。ドメイン間のスタイルが異なるだけでなく、入力と出力で幾何学的構造も大きく異なる。例えば目のサイズ。このせいで、構造の不一致や歪み、アーティファクトを発生させることにつながる。

なお、プロジェクトページには、selfie2animeテストデータ全100件に対する各im2im手法の生成結果を載せた補足資料が公開されている。GAN論文はいい変換例だけ載せる疑惑があるので、こういうのは助かる。ただ、この資料を見るに、論文に載せている例は変換がうまくいったものを図に採用している感は否めない。

また、定量評価(FID, KID)では先行手法よりも良い結果だったとのこと。

f:id:meow_memow:20200709084736j:plain

何番目のcouncilを使ったのかは不明。 見た目的にU-GAT-ITと差がない印象であるが、定量的には差が出ている模様。

Council GAN を手元で試してみる

公式PyTorch実装がgithubで公開されているので試す。
https://github.com/Onr/Council-GAN

実行環境のセットアップ

Anacondaでの実行を想定しているらしくconda_requirements.ymlが用意されている。requirements.txtは用意されていない。なので私は何回かコードを実行して不足しているパッケージを都度入れた。

ちなみに、Dockerfileも用意されている。公式の現在(2020/07/04)のバージョンのDockerfileはビルドが通らないので、有志がプルリクしているDockerfileを使えばビルドできる

データのダウンロード

  • selfie2animeデータセットのダウンロード
    • $ bash ./scripts/download.sh U_GAT_IT_selfie2anime
      • 先行研究の UGATITで使用されていたものを落としてきている。
      • 画像の拡張子はjpgだが、実態のフォーマットはpngである。処理に問題はない。
    • datasets/selfie2animeに落とされる。
    • データサイズはソース、ターゲットドメインそれぞれ、
      • train: 3,400 枚
      • test: 100 枚
    • あと、データは全員女性4なので、このデータで学習したモデルが男性の画像で正常に変換できる保証はない。
  • 学習済みモデルのダウンロード
    • $ bash ./scripts/download.sh pretrain_selfie_to_anime
    • ./pretrain/anime/に4つのgeneratorが落とされる。

実行(コマンドライン)

  • モデル学習
    • $ python train.py --config configs/anime2face_council_folder.yaml --output_path ./outputs/council_anime2face_256_256 --resume
      • 使用GPUメモリはデフォルトで4000MBくらい
      • チェックポイントは10000stepごとに取られる。
      • チェックポイントがある状態でコマンドを実行すると、最後のチェックポイントから学習が再開する
        • ちなみに、前述の学習済みモデルには、識別器とoptimizerのチェックポイントがないので、学習済みモデルで学習を再開することはできない。
    • モデルのファイル名のフォーマットは<ドメインの方向>_<モデルの役割>_<counsilの番号>_<学習で回したステップ数>.pt
      • 例えばb2a_gen_0_01000000.ptは、ドメインBからA、Generator、0番目のcousil、1,000,000step学習したモデルということ
  • 学習済みモデルで評価
    • $ python test_on_folder.py --config pretrain/anime/256/anime2face_council_folder.yaml --output_folder outputs/council_anime2face_256_256 --checkpoint pretrain/anime/256/01000000 --input_folder ./datasets/selfie2anime/testB --a2b 0
    • outputs/test_resに画像ディレクトリができる
      • デフォルトだと10つできる。
        • 4つのcouncilのうちランダムにgeneratorを選んだ後、decodeフェーズでノイズを混ぜて画像生成している。それを10回繰り返している
        • ノイズはランダムに毎回生成しているため、画像は毎回実行結果が異なる。
    • 自分の用意した顔写真で試したい場合は、datasets/selfie2anime/testBに顔写真を入れればよい。

実行(GUI)

PyQtベースのGUIも用意されていたが、私の環境では動かなかった。 こんなエラーが出る。

$ python test_gui.py --config pretrain/anime/256/anime2face_council_folder.yaml --checkpoint pretrain/anime/256/b2a_gen_3_01000000.pt --a2b 0

test_gui.py:56: UserWarning: Filed to import face_recognition, setting use_face_locations to FALSE
warnings.warn("Filed to import face_recognition, setting use_face_locations to FALSE")
qt.qpa.xcb: could not connect to display
qt.qpa.plugin: Could not load the Qt platform plugin "xcb" in "" even though it was found.
This application failed to start because no Qt platform plugin could be initialized. Reinstalling the application may fix this problem.

Available platform plugins are: eglfs, linuxfb, minimal, minimalegl, offscreen, vnc, wayland-egl, wayland, wayland-xcomposite-egl, wayland-xcomposite-glx, webgl, xcb.

zsh: abort (core dumped) python test_gui.py --config pretrain/anime/256/anime2face_council_folder.yaml

lennaさん画像で変換を試して比較してみる

f:id:meow_memow:20200709085014j:plain

f:id:meow_memow:20200709085028j:plain

左上が元画像(ちょっとクロップしている)、上真中が私がlennaさん用にガッチガチにチューニング5したCycleGANの結果。 右上がU-GAT-IT6。形状のデフォルメはされていない。attentionをいれているので形状に注目するはずなのだが。
下がCouncil GAN 10パターン。良いものが出るまで生成ガチャを回した。かなりアニメっぽい。とくに1行5列目のlennaさんはアニメっぽくて気に入っている。

所感

いい画像が生成できるかは確率次第だったので、平均的な性能はU-GAT-ITとあまり変わらないと思う。
タイトルで"all you need" と言っているが、selfie2animeタスクに関しては必要な要素はまだまだありそう。


  1. CVPR2020の開催時期に合わせたものと思われる。

  2. LSGAN の識別器,生成器の学習のコードはこんな感じで実装する。

  3. CycleGAN, a Master of Steganography, https://arxiv.org/abs/1712.02950

  4. このご時世、見た目だけで"女性"というのはポリコレ事案な気がするが、どう言えばいいのかわからないので、とりあえず女性と言っておく。

  5. 詳しくはこちらのスライドで説明: CycleGANで顔写真をアニメ調に変換する

  6. U-GAT-ITの私の解説もよろしければご参照いただければと思います。: 【論文紹介】U-GAT-IT