Does ICLR’22 How Does SimSiam Avoid Collapse Without Negative Samples?

Abstract

  • SimSiamがなぜnegative sampleを用いず学習の崩壊を防げているのかの理由はまだ十分に解明されていない

  • まずSimSiamにおける説明的主張を再検討して、それに反論する

  • そして、表現ベクトルを中心成分と残差成分と分解して、predictorを外したシンメトリックなSimSiamは崩壊を防ぐことができず、アシンメトリックにするとextra gradientが出てきて、その中心ベクトルは de-centering効果によって、残差ベクトルはdimensional de-correlation によって崩壊を防ぐ (???)

??? という感じですが、とりあえず先に進んでいきます。

predictorはEOA近似のギャップを埋めているのか?

(SimSiam論文の仮説への反論)

EOAって? Expectation over augmentations のこと (\(\mathbb{E}_{\mathcal{T}}[\cdot]\) )

../_images/how_fig1.png
  • 図1(a)のようにSimSiamのpredictorはstop gradientしない方のencoder側にある

  • SimSiam論文のProof of concept(5.2節) でやったmoving-averageは図1(b)のようになる

  • predictorをEOAと解釈するのは図1(a)というより図1(c)になる。

  • なので、SimSiamのpredictorをEOAとみなすのは無理がある

SimSiam論文では、「\(\mathbb{E}_{\mathcal{T}}[\cdot]\) を計算するのは非現実的だが、\(\mathcal{T}\) が複数のepoch間で暗黙的に分散しているならpredictorが期待値を予測することは可能かも」と言っているが、 本論文は十分に大きい数 \(\mathcal{T}\) からサンプリングして、最新のモデルを通したものの平均をとったほうがより有効でしょうと言っている。

しかし、そうしてしまうとTable 1に示すとおりモデルは崩壊してしまう。 (やっぱりpredictorをEOAとみなすのは無理がある)

../_images/how_tab1.png

Asymmetric interpretation of predictor with stop gradient in SimSiam

話は変わって、どういう構造なら崩壊して、どういう構造なら崩壊しないのか という話になる (なぜ崩壊しないのか?の答えではない)

../_images/how_fig2.png
  • 図2(a) Naive SimSiam: SimSiamからpredictorを除いたものは崩壊する (Table 2)

  • 図2(b) Symmetric Predictor : stop gradientする方にpredictorをつけたものも崩壊する (Table 2)

    • 図2(b)は結局図2(a)の \(f(x)\)\(h(f(x))\) とみなすだけなので、崩壊する

  • 図2(c) Inverse Predictor : 図2(b)のstop gradientする方にpredictorの逆関数をつけると崩壊しない (Figure 3)

pic1   pic2

Inverse predictorなんて用意できるのか?

  • Inverse predictor \(h^{-1}\) も同時に \(P, Z\) の距離を近づけるように学習する

    • 図3が示すように \(h^{-1}\) は学習可能である

    • \(h^{-1}\) は理論的にrandom augmentation \(\mathcal{T}'\) を restore できないので SimSiamにおけるpredictorはEOAではないさらなる証拠だと主張している (よくわからない)

../_images/how_alg5.png

Vector decomposition for understanding collapse

  • \(z\) : representation vector (\(z=f(x)\))

  • \(Z\) : zを正規化したもの (\(Z=z/\|z\|\))

(1)\[\mathcal{L}_{MSE} = (Z_a - Z_b)^2 / 2 = -Z_a \cdot Z_b = L_{cosine}\]
  • \(P\) : normalized output of predictor ( \(P = p / \|p\|\) )

(2)\[\mathcal{L}_{SimSiam} = - (P_a \cdot sg(Z_b) + P_b \cdot sg(Z_a))\]

\(Z\)\(Z = o + r\) と2つのベクトルに分ける

  • center vector \(o\) : Zの期待値 (\(o_z = \mathbb{E}[Z]\) )だが、minibatch内の標本平均で近似する \(o_z = \frac{1}{M}\sum_{m=1}^M Z_m\)

  • residual vector \(r\) : Zの残差成分 (\(r = Z - o_z\))

  • ration of o \(m_o := \| o \| / \|z\|\)

  • ration of r \(m_r := \|r\| / \|z\|\)

崩壊が起こると、\(Z\) は center vector \(o\) に近くなり、\(m_o\) は1に近くなり、\(m_r\) は0 に近くなる。 なので、 \(m_o \gg m_r\) となるとき、崩壊していると解釈する。

推測1 \(Z_a = o_z + r_a\) とすると、\(o_z\) の勾配成分は \(m_o\) を増加させ、\(r_a\) の勾配成分は逆に \(m_r\) を増加させる。

推測1を検証するために、dummy gradient term \(Z_a\) に立ち返る。 \(-Z_a \cdot sg(o_z)\)\(-Z_a \cdot sg(r_a)\) 2種類のロスをデザインして、それぞれ、\(o\)\(r_a\) のgradient componetの影響を調べたのが図4.

../_images/how_fig4.png

gradient component \(o_z\) は \(m_o\) を増加させ、gradient component \(r_a\)\(m_r\) を増加させることがわかる。

Extra gradient component for alleviating collapse

ロス関数の negative gradient

(3)\[- \frac{\partial \mathcal{L}_{MSE}}{\partial Z_a} = Z_b - Z_a \Leftrightarrow -\frac{\partial \mathcal{L}_{cosine}}{\partial Z_a} = Z_b\]

, where the gradient component \(Z_a\) is a dummy term because the loss \(- Z_a \cdot Z_a = -1\) is a constant having zero gradient on the encoder \(f\).

(3) はtwo equivalent formsとして解釈でき \(Z_b - Z_a\) を選ぶと (???)、 \(Z_b - Z_a = (o_z + r_b) - (o_z + r_a) = r_b - r_a\) になる。 \(r_b\)\(r_a\) と同じpositive sampleから来ているので、\(r_b\) も同じく \(m_r\) を増加させることが期待されるが、その効果は \(r_a\) より小さいので (???)、崩壊の原因になる.

よくわかんないポイント

  • その効果は \(r_a\) より小さいのはなぜ ( \(r_b - r_a\) が小さくなるならわからんでもない )

  • \(Z_b - Z_a\) ではなく、\(Z_b\) を選ぶと \(o_z + r_b\) だけど

  • そもそも \(Z_a\) で微分したものを見て何になるのか (微分したものを見るというのは感覚的にわかるが、なんで \(Z_a\) で微分したものでよいのかの議論がほしい)

図2(a)のnegative gradient on \(Z_a\)\(Z_b\), 図2(b)のnegative gradient on \(P_a\)\(P_b\) となる。 \(Z_b, P_b\) を Basic Gradientとする. 上記の解釈から Basic Gradient では崩壊を防ぐことができないので、対称性を壊すために余分な成分を導入する必要がある。その余分な成分を Extra Gradient と呼び、\(G_e\) と表記する.

例えば図2(a)に negative sample を導入することで、negative sampleによる余分な成分( \(G_e\) )があるので崩壊しない。同様にSimSiamの \(P_a\) についての負の勾配 \(Z_b\) を basic gradient \(P_b + G_e\) と導出することも可能である ( \(G_e = Z_b - P_b\) )。 ???

( どういうことなのか??? 言葉遊びしてるだけに見える。)

どの成分が崩壊を防いでいるのか?

\(G_e\)\(Z_a\) と同じように center vectorと residual vectorに分解する ( \(G_e = o_e + r_e\) ) \(G_e, o_e, r_e\) どの成分が崩壊を防いでるのか、triplet loss \(\mathcal{L}_{tri} = - Z_a \cdot sg(Z_b - Z_n)\) , ( \(Z_n\) はnegative sampleの表現ベクトル ) で実験してみる。\(Z_a\) についてのnegative gradientは \(Z_b - Z_n\)\(Z_b\) が basic gradientなので、 \(G_e = - Z_n\) となる。

表3に \(Ge\) の代わりに \(o_e, r_e\) を入れてみたときの学習結果を示す。 \(o_e\) が崩壊を防いでいるのがわかり、\(r_e\) だけでは崩壊して、 \(r_e\) を保持している \(G_e\) だと精度が低下する。 negative sampleはランダムに選ばれるため、 \(r_e\) は最適化においてランダムノイズのように振る舞い、性能を低下させる (???? そうなの???)

../_images/how_tab3.png

SimSiamの \(P_a\) における負の勾配は以下

\[-\frac{\partial \mathcal{L}_{SimSiam}}{\partial P_a} = Z_b = P_b + (Z_b - P_b) = P_b + G_e\]

triplet lossでやったような実験をすると、表4になる

../_images/how_tab4.png

予想どおり \(G_e\) を取り除くと崩壊して、\(o_e, r_e\) の両成分を残すと最高の性能を発揮する (そうなの???) 興味深いことに \(o_e, r_e\) のどちらかを残せば崩壊しない。 推論1にもどついて、 \(o_e\) がどのような影響を与えるのか分析する。

SimSiamで o_e がどのように崩壊を防いでいるのか

\(G_e = Z_b - P_b\) なので、\(o_e = o_z - o_p\) である。 推論1によると( \(P_a\) についてみているので) 負の \(o_p\) は崩壊を防ぐので \(o_e\) は崩壊を防ぐ.

\(o_p\) がどれくらい \(o_e\) に影響しているかをみるために、 図5に、スカラー \(\eta_p\) を動かしたときの \(o_e - \eta_p o_p\)\(o_p\) のコサイン類似度を示す。 \(\eta_p\)\(-0.5\) くらいのとき、コサイン類似度が0になるので、\(o_e \approx -0.5 o_p\) になる.

(直角になってもコサイン類似度0じゃないのか?と思ったけど、 \(\eta_p\) はスカラーなので、直角になることはないから コサイン類似度が0ということは、 \(o_e - \eta_p o_p\) が0になるしかにということ?)

というわけで、 \(o_e\) がSimSiamが崩壊するのを防ぐ

../_images/how_fig5.png

triplet lossの実験では \(r_e\) を維持すると精度下がったが、SimSiamでは \(r_e\) だけでも崩壊を防いで \(G_e\) と同等の性能を達成している。これは次で説明する dimensional de-correlation の観点から説明できる。

推論2 dimensional de-correlation が \(m_r\) を増加させると考える。動機は単純で、次元の相関が最小になるのは、1つの次元だけが個々のクラスに対して非常に高い値を持ち、異なるクラスに対して次元が変化する場合であるから.

上記の推測を検証するため、SimSiamの損失 (2) で訓練し、式 (1) の損失で意図的に \(m_r\) を0に近づけて数エポック訓練する。次に、付録A.6で詳述する相関正則化項のみを用いて損失を学習させる。図5(b)の結果から、この正則化項が非常に速い速度で \(m_r\) を増加させていることがわかる。

../_images/how_alg6.png

\(h\)\(o_e\) の影響を排除するためにFC層を一つしか持たないと仮定すると、FCの重みはエンコーダ出力の異なる次元間の相関を学習することが期待される。本質的に、 \(h\)\(h(Z_a)\)\(I(Z_b)\) の間のコサイン類似度を最小化するように学習される(Iは同一性写像)。したがって、相関を学習する \(h\)\(I\) に近づくように最適化され、これは概念的には \(Z\) に対して脱相関を目標に最適化することと等しい。表4に示すように、SimSiamでは \(r_e\) のみでも崩壊を防ぐことができ、 \(r_e\) には脱中心化の効果がないため、脱相関効果に起因するものであると考えられる。図6からは最初の数エポックを除き、SimSiamは共分散を減少させている。

(何言ってるかわからない・・・)

../_images/how_fig6.png