Grover ICLR’19 Stochastic Optimization of Sorting Networks via Continuous Relaxations

https://openreview.net/forum?id=H1eSS3CcKX

著者

  • Aditya Grover (Computer Science Department Stanford University)

  • Eric Wang (Computer Science Department Stanford University)

  • Aaron Zweig (Computer Science Department Stanford University)

  • Stefano Ermon (Computer Science Department Stanford University)

概要

  • sortをargmaxを用いて記述 → argmaxをsoftmaxで代替して微分可能にする

  • 順列の確率がPlackett-Luce分布に従うとした設定のversionも提案

前置き

Permutation matrixの定義

まずPermutation matrixを導入します。

  • \(\mathbf{z} = [z_1, \ldots, z_n]^\top\) を長さnのユニークインデックス \(\{1, 2, \ldots, n\}\) のリストとする (zの取りうる集合を \(\mathcal{Z}_n\) とする)

  • そして \(\mathbf{z}\) のpermutation matrix \(P_{\mathbf{z}} \in \{0, 1\}^{n \times n}\) を以下のように定義する (\({P_{\mathbf{z}} [i, j]}\)\((i, j)成分\) )

\begin{align} P_{\mathbf{z}}[i, j] := \left\{ \begin{array}{ll} 1 & j = z_i \\ 0 & \text{otherwise}. \end{array} \right. \end{align}

具体例

  • \(\mathbf{z} = [1, 3, 4, 2]^\top\) とすると \(\mathbf{z}\) のPermutation Maxtrixは以下になる。

\begin{align} P_{\mathbf{z}} = \begin{pmatrix} 1 & 0 & 0 & 0 \\ 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 1 \\ 0 & 1 & 0 & 0 \end{pmatrix} \end{align}

sort関数の定義

そして、\(\mathbf{s} = [s_1, \ldots, s_n]^\top\) を長さ \(n\) の実数のベクトル、 \([i]\)\(\mathbf{s}\) の中で \(i\) 番目に大きい値のindexとして、 sort関数 \(\mathbb{R}^n \to \mathcal{Z}_n\) を次のように定義します。

\begin{align} \text{sort}(\mathbf{s}) := [[1], [2], \ldots, [n]] \end{align}

そうすると ベクトル \(\mathbf{s}\) を(降順に)ソートしたベクトルは \(P_{\text{sort}(\mathbf{s})} \mathbf{s}\) と書けるようになる。


具体例

  • \(\mathbf{s} = [9, 1, 5, 3]^\top\) とすると

  • 1番大きい9のindexは1, 2番目に大きい5のindexは3、・・・なので \(\text{sort}(\mathbf{s}) = [1, 3, 4, 2]^\top\) になる

  • そして、\(\text{sort}(\mathbf{s})\) のPermutation matrixは定義より以下となる

\begin{align} P_{\text{sort}(\mathbf{s})} = \begin{pmatrix} 1 & 0 & 0 & 0 \\ 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 1 \\ 0 & 1 & 0 & 0 \end{pmatrix} \end{align}
  • そして \(P_{\text{sort}(\mathbf{s})} \mathbf{s} = [9, 5, 3, 1]\) となるので、ちゃんとソートされている。


Corollary 3 (sortを数学的に記述する)

  • \(\mathbf{s}\) の要素間の差の行列を \(A_{\mathbf{s}}\) とする (つまり \(A_{\mathbf{s}}[i,j] := |s_i - s_j |\))

  • すると、\(\text{sort}(\mathbf{s})\) のPermutation matrix \(P_{\text{sort}(\mathbf{s})}\) は以下になる

(1)\[\begin{split}\begin{align} P_{\text{sort}(\mathbf{s})} [i, j] = \left\{ \begin{array}{ll} 1 & j = \arg \max [(n+1-2i) \mathbf{s} - A_{\mathbf{s}} \mathbb{1} ]\\ 0 & \text{otherwise}. \end{array} \right. \end{align}\end{split}\]
  • (\(\mathbb{1}\) は要素がすべて1のcolumn vector)


証明: Proof of Corollary 3


具体例

  • \(\mathbf{s} = [9, 1, 5, 3]^\top\) とすると \(A_{\mathbf{s}}\) は以下になる

\begin{align} A_{\mathbf{s}} = \begin{pmatrix} 0 & 8 & 4 & 6 \\ 8 & 0 & 4 & 2 \\ 4 & 4 & 0 & 2 \\ 6 & 2 & 2 & 0 \end{pmatrix} \end{align}
  • \(i = 1\) のとき、以下なので \(\arg \max [(n+1-2i) \mathbf{s} - A_{\mathbf{s}} \mathbb{1} ] = \arg \max[9, -11, 5, -1] = 1\)

\[(n+1-2i) \mathbf{s} - A_{\mathbf{s}} \mathbb{1} = 3s - A_{\mathbf{s}} \mathbb{1} = [27, 3, 15, 9] - [18, 14, 10, 10] = [9, -11, 5, -1]\]
  • \(i = 2\) のとき、以下なので \(\arg \max [(n+1-2i) \mathbf{s} - A_{\mathbf{s}} \mathbb{1} ] =\arg \max [-9, -13, -5, -9] = 3\)

\[(n+1-2i) \mathbf{s} - A_{\mathbf{s}} \mathbb{1} = s - A_{\mathbf{s}} \mathbb{1} = [9, 1, 5, 3] - [18, 14, 10, 10] = [-9, -13, -5, -9]\]
  • \(i = 3\) のとき、以下なので \(\arg \max [(n+1-2i) \mathbf{s} - A_{\mathbf{s}} \mathbb{1} ] = \arg \max [-27, -15, -15, -13] = 4\)

\[(n+1-2i) \mathbf{s} - A_{\mathbf{s}} \mathbb{1} = -s - A_{\mathbf{s}} \mathbb{1} = [-9, -1, -5, -3] - [18, 14, 10, 10] = [-27, -15, -15, -13]\]
  • \(i = 4\) のとき、以下なので \(\arg \max [(n+1-2i) \mathbf{s} - A_{\mathbf{s}} \mathbb{1} ] = \arg \max [-45, -17, -25, -19] = 2\)

\[(n+1-2i) \mathbf{s} - A_{\mathbf{s}} \mathbb{1} = -3s - A_{\mathbf{s}} \mathbb{1} = [-27, -3, -15, -9] - [18, 14, 10, 10] = [-45, -17, -25, -19]\]
  • \(\text{sort}(\mathbf{s}) = [1, 3, 4, 2]^\top\) なのであっている。

提案法: NeuralSort

モチベーション

次のようなsortを含む目的関数をgradient-based methodで最適化することが目標 (\(\theta\) がモデルパラメータで \(\mathbf{s}\)\(\theta\) に依存)

\[L(\theta, \mathbf{s}) = f(P_{\mathbf{z}}; \theta), ~~~ \text{where} ~~ \mathbf{z} = \text{sort}(\mathbf{s}).\]

提案法

argmaxは微分できないので、softmaxで置換して、sortのPermutation matrixを以下のようにrelaxationする。 (\(\tau\) は温度パラメータ)

(2)\[\hat{P}_{\text{sort}(\mathbf{s})} [i, :] (\tau) = \text{softmax} [ ((n+1-2i) \mathbf{s} - A_{\mathbf{s}} \mathbb{1}) / \tau ]\]

tensorflowで実装すると次のような感じ(batch版)

import tensorflow as tf

def neural_sort(s, tau=0.1):
    m = s.shape[0]
    n = s.shape[1]
    es = tf.repeat(tf.expand_dims(s, axis=1), n, axis=1) 
    A = tf.abs(tf.transpose(es, perm=[0,2,1]) - es)
    i = tf.range(1, n+1, dtype=tf.float32)
    A1 = tf.repeat(tf.expand_dims(tf.reduce_sum(A, axis=1), axis=1), n, axis=1)
    pi = tf.multiply(n+1-2*i[:, None], es) - A1
    pm = tf.nn.softmax(pi/tau)

    sort_s = tf.linalg.matvec(pm, s)

    b_hat = tf.expand_dims(tf.repeat(i[None, :], m, axis=0), axis=2)
    rank_s = tf.linalg.matmul(pm, b_hat)[...,0]
    return sort_s, rank_s
s = tf.constant([[9,1,5,3,], [-6.5, 0.4, 1.5, 3.8,]])
ss, sr = neural_sort(s)

print("s =\n", s.numpy())
print("sort(s) =\n", ss.numpy())
print("rank(s) =\n", sr.numpy())
s =
 [[ 9.   1.   5.   3. ]
 [-6.5  0.4  1.5  3.8]]
sort(s) =
 [[ 9.         5.         3.         1.       ]
 [ 3.8        1.4999816  0.4000184 -6.5      ]]
rank(s) =
 [[1.        3.        4.        2.       ]
 [4.        2.9999833 2.0000167 1.       ]]

Theorem 4

  • Limiting behavior: \(\mathbf{s}\) の各要素が \(\mathbb{R}\) のルベーグ測度に対して絶対連続な分布から独立に引かれていると仮定すると、以下が成り立つ

(3)\[\lim_{\tau \rightarrow 0^+} \hat{P}_{\text{sort}(\mathbf{s})} [i, :] (\tau) = {P}_{\text{sort}(\mathbf{s})} [i, :] ~~~ \forall i \in \{1,2,\ldots, n\}.\]
  • Unimodality: \(\forall \tau > 0\) において \(\hat{P}_{\text{sort}(\mathbf{s})}\) は unimodal row stochastic matrixである。さらに、 \(\hat{P}_{\text{sort}(\mathbf{s})}\) の各行にargmaxを取ったベクトルは \(\text{sort}(\mathbf{s})\) に一致する。


Limiting behaviorについて

  • \(\text{hardmax}(\mathbf{s})\)\(i := \arg \max(\mathbf{s})\) としてindex iだけが1で他のindexが0のベクトルを返す関数だとする

  • \({P}_{\text{sort}(\mathbf{s})} [i, :] = \text{hardmax}((n+1-2i) \mathbf{s} - A_{\mathbf{s}} \mathbb{1})\) とかける

  • \(\text{hardmax}(\mathbf{s}) = \arg \max_{\mathbf{x} \in \Delta^n } \langle \mathbf{x},\mathbf{s} \rangle\) とも書ける

  • 一方softmaxは次のように書ける (ラグランジュの未定乗数法を使った softmaxの変形の証明)

\[\text{softmax}(\mathbf{s}/\tau) = \arg \max_{\mathbf{x} \in \Delta^n } \left[\langle \mathbf{x},\mathbf{s} \rangle - \tau \sum_{i=1}^n x_i \log x_i \right]\]
  • なので \(\lim_{\tau \rightarrow 0^+} \text{softmax}(\mathbf{s}/\tau) = \arg \max_{\mathbf{x} \in \Delta^n } \langle \mathbf{x},\mathbf{s} \rangle = \text{hardmax}(\mathbf{s})\) なので

  • \(\lim_{\tau \rightarrow 0^+} \hat{P}_{\text{sort}(\mathbf{s})} [i, :] (\tau) = {P}_{\text{sort}(\mathbf{s})} [i, :] ~~~ \forall i \in \{1,2,\ldots, n\}\)


Unimodalityについて

まず定義します。

Definition 1 (Unimodal Row Stochastic Matrices): 以下の3つを満たす行列のこと

  1. Non-negativity: \(U[i,j] \geq 0 ~~~ \forall i, j \in \{1,2,\ldots, n\}.\)

  2. Row Affinity: \(\sum_{j=1}^n U[i,j] = 1 ~~~ \forall i, j \in \{1,2,\ldots, n\}.\)

  3. Argmax Permutation: \(u_i = \arg \max_j U[i, j]\) とすると \(\mathbf{u} \in \mathcal{Z}_n\) であること (つまり \(\mathbf{u}\) がvalid permutationであること)

証明

  • Non-negativity, Row Affinity はsoftmaxの定義より成り立つ。

  • Argmax Permutationは以下がなりたち、 \(\mathbf{u} = \text{sort}(\mathbf{s})\) なので成り立つ。

\begin{align} u_i &= \arg \max [\text{softmax} ((n+1-2i) \mathbf{s} - A_{\mathbf{s}} \mathbb{1}) ] \\ &= \arg \max [ ((n+1-2i) \mathbf{s} - A_{\mathbf{s}} \mathbb{1}) ] ~~~~ (\text{softmaxのmonotonicityより}) \\ &= [i] ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ (\text{Corollary3より}) \end{align}

Unimodalityがあるとなにが嬉しいのか?

  • 特に論文中に説明はなかったが、\(P_{\text{sort}(\mathbf{s})}\) が持っている性質なので \(\hat{P}_{\text{sort}(\mathbf{s})}\) も同じ性質を持っているからいいよねって感じだと思いました

(疑問)

  • 実践的にはsoftmaxよりsoftargmaxのほうがいいんじゃないかと思ったが、どうなのだろうか

  • DCGに適用した場合、ApproxNDCG (Approx NDCGについて詳しく) と結果的にはあんまりかわらないんじゃないかという気がしてきた

Stochastic Optimization over Permutations

今までは

\[L(\theta, \mathbf{s}) = f(P_{\mathbf{z}}; \theta)\]

今度は以下を考える (\(q(\cdot)\)\(\mathcal{Z}_n\) の要素上の分布)

(4)\[L(\theta, \mathbf{s}) = \mathop{\mathbb{E}}_{q(z|s)} [f(P_{\mathbf{z}}; \theta)]\]
../_images/ltr_ns_fig3.png
  • MCMCで \(\theta\) に関する勾配の不偏推定量は得られるが、sampling distributionが \(\mathbf{s}\) に依存するため、\(\mathbf{s}\) に関する勾配の推定量は直接求められない。

  • REINFORCE gradient estimator (Glynn, 1990; Williams, 1992; Fu, 2006) は \(\nabla_s q(\mathbf{z}|\mathbf{s}) = q(\mathbf{z}|\mathbf{s}) \nabla_s \log q(\mathbf{z}|\mathbf{s})\) を使って、 MC gradient estimationは

\[\nabla_s L(\theta, \mathbf{s}) = \mathop{\mathbb{E}}_{q(z|s)} [f(P_{\mathbf{z}}; \theta) \nabla_s \log q(\mathbf{z}|\mathbf{s})] + \mathop{\mathbb{E}}_{q(z|s)} [\nabla_s f(P_{\mathbf{z}}; \theta)]\]
  • だが、REINFORCE gradient estimator は High Varianceに苦しんでいる (Schulman et al., 2015; Glasserman, 2013).


reparameterization trickを使うと、 \(\mathbf{g}\) をgumbel noiseとして 式 (4) は以下になる

\[L(\theta, \mathbf{s}) = \mathop{\mathbb{E}}_{\mathbf{g}} [f(P_{\text{sort}(\log \mathbf{s} + \mathbf{g})}; \theta)]\]

となり、argmaxをsoftmaxで置換して緩和したsortにして、微分したものは

\[\nabla_s \hat{L} (\theta, \mathbf{s}) = \mathop{\mathbb{E}}_{\mathbf{g}} [ \nabla_s f(\hat{P}_{\text{sort}(\log \mathbf{s} + \mathbf{g})}; \theta)]\]

これは、期待値がsに依存しない分布に関するものであるため、モンテカルロ法で効率的に推定できる。

(実験をみると Stochasticにしても性能が上がるわけではない)

Proof of Corollary 3

まず Lemma 2 [Lemma 1 in Ogryczak & Tamir (2003)]

(5)\[\sum_{i}^k s_{[i]} = \min_{\lambda \in \{ s_1, \ldots, s_n \}} \lambda k + \sum_{i=1}^n \max(s_i - \lambda, 0).\]

(書きかけ)

softmaxの変形の証明

\[\text{softmax}(\mathbf{s}/\tau) = \arg \max_{\mathbf{x} \in \Delta^n } \left[\langle \mathbf{x},\mathbf{s} \rangle - \tau \sum_{i=1}^n x_i \log x_i \right]\]

ラグランジュの未定乗数法を使うと、ラグランジュ関数は

\[\mathcal{L}(\mathbf{x}, \mu, \mathbf{\lambda}) = \langle \mathbf{x},\mathbf{s} \rangle - \tau \sum_{i=1}^n x_i \log x_i + \mu \left(\sum_{i=1}^n x_i - 1 \right) - \sum_{i=1}^n \lambda_i x_i.\]

(第三項は足して1の制約から、第四項はすべての要素が非負の制約から)

スレーター制約を満たすのでKKT条件が最適性の必要十分条件になって、KKT条件は以下。

\begin{align} \cfrac{\partial \mathcal{L}}{\partial x_i} &= s_i - \tau \log x_i - \tau + \mu - \lambda_i = 0, ~~~ \forall i \\ \mu \left( \sum_{i} x_i -1 \right) &= 0, \\ \lambda_i x_i &= 0, \\ x_i &\ge 0, \end{align}

1つ目の条件から

(6)\[x_i = \exp \left(\frac{s_i - \tau + \mu - \lambda_i }{\tau} \right)\]

となるので \(x_i > 0\) になり、また \(\lambda_i x_i = 0\) なので \(\lambda_i = 0\) になる。

2つ目の条件と式 (6) から

(7)\[\begin{split}\begin{aligned} & \sum_{i} x_i = 1 \Leftrightarrow \sum_{i} \exp \left(\frac{s_i - \tau + \mu }{\tau} \right) = 1 \notag \\ \Leftrightarrow & \sum_{i} \exp \left(\frac{s_i}{\tau} \right) \exp(-1) \exp \left(\frac{ \mu }{\tau} \right) = 1 \Leftrightarrow \exp \left(\frac{ \mu }{\tau} \right) = \cfrac{e}{\sum_{i} \exp \left(\frac{s_i}{\tau} \right)} \end{aligned}\end{split}\]

また、式 (6) から

(8)\[x_i = \exp \left(\frac{s_i}{\tau} \right) \exp(-1) \exp\left(\frac{\mu}{\tau} \right) \Leftrightarrow x_i e \exp \left(- \frac{s_i}{\tau} \right) = \exp\left(\frac{\mu}{\tau} \right)\]

(7)(8) から

\[x_i e \exp \left(- \frac{s_i}{\tau} \right) = \cfrac{e}{\sum_{i} \exp \left(\frac{s_i}{\tau} \right)} \Leftrightarrow x_i = \cfrac{\exp \left(\frac{s_i}{\tau} \right)}{\sum_{i} \exp \left(\frac{s_i}{\tau} \right)}\]

よって、以下になる。

\[\arg \max_{\mathbf{x} \in \Delta^n } \left[\langle \mathbf{x},\mathbf{s} \rangle - \tau \sum_{i=1}^n x_i \log x_i \right] = \cfrac{\exp \left(\frac{\mathbf{s}}{\tau} \right)}{\sum_{i} \exp \left(\frac{\mathbf{s}}{\tau} \right)} = \text{softmax}(\mathbf{s}/\tau)\]