Optimal Transport

Let us say we need to hold a party: we have \(n=8\) people and \(m=5\) kinds of snacks. Let \(\mathbf{r} = (3,3,3,4,2,2,2,1)^\intercal\) be the vector of dimension \(n\) containing the amount of snacks each person can eat. Similarly, \(\mathbf{c}=(4, 2, 6, 4, 4)^\intercal\) of dimension \(m\) denotes how much there is of each snack.

Often \(\mathbf{r}\) and \(\mathbf{c}\) represent marginal probability distributions, in which case their values sum to one. Here we keep the raw portions; all that matters is that \(\mathbf{r}\) and \(\mathbf{c}\) have the same total (here, 20).

[Figure: snacks_people, the portions each person can eat]

[Figure: snacks, the available amount of each snack]

We are also given each person's preference for every snack:

[Figure: snacks_pref, the preference of each person for each snack]

Optimal transport formulation

Our goal is to find a matrix \(P\in \mathbb{R}_{>0}^{n\times m}\), where \(P_{ij}\) is the amount of snack \(j\) that person \(i\) should be given.

Let \(U(\mathbf{r}, \mathbf{c})\) be the set of positive \(n\times m\) matrices for which the rows sum to \(\mathbf{r}\) and the columns sum to \(\mathbf{c}\):

\[ U(\mathbf{r}, \mathbf{c}) = \{P\in \mathbb{R}_{>0}^{n\times m}\mid P\mathbf{1}_m = \mathbf{r}, P^\intercal\mathbf{1}_n = \mathbf{c}\}\,. \]

Thus, \(U(\mathbf{r}, \mathbf{c})\) contains all feasible solutions. The preferences of each person are stored in a cost matrix \(M\in \mathbb{R}^{n\times m}\) (we can turn a preference matrix into a cost matrix by flipping the sign of every element).

So the problem we want to solve is formally posed as

\[ d_M(\mathbf{r}, \mathbf{c}) = \min_{P\in U(\mathbf{r}, \mathbf{c})}\, \sum_{i,j}P_{ij}M_{ij}\,. \]

This is called the optimal transport between \(\mathbf{r}\) and \(\mathbf{c}\). It can be solved relatively easily using linear programming.
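As a minimal sketch of that linear-programming route (assuming SciPy is available; the helper name optimal_transport_lp and the constraint setup are ours for illustration):

import numpy as np
from scipy.optimize import linprog

def optimal_transport_lp(M, r, c):
    """Solve min <P, M> over U(r, c) as a linear program (illustrative sketch)."""
    n, m = M.shape
    A_eq, b_eq = [], []
    for i in range(n):                  # row-sum constraints: P 1_m = r
        row = np.zeros((n, m))
        row[i, :] = 1.0
        A_eq.append(row.ravel())
        b_eq.append(r[i])
    for j in range(m):                  # column-sum constraints: P^T 1_n = c
        col = np.zeros((n, m))
        col[:, j] = 1.0
        A_eq.append(col.ravel())
        b_eq.append(c[j])
    # bounds=(0, None) allows zeros; the LP optimum is typically sparse
    res = linprog(np.asarray(M, dtype=float).ravel(),
                  A_eq=np.array(A_eq), b_eq=np.array(b_eq),
                  bounds=(0, None), method="highs")
    return res.x.reshape((n, m)), res.fun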

The optimum, \(d_M(\mathbf{r}, \mathbf{c})\), is called the Wasserstein metric or Wasserstein distance, a distance between two probability distributions. It is sometimes also called the earth mover's distance, as it can be interpreted as the minimal amount of dirt you have to move to change one landscape (distribution) into another.

Sinkhorn distance

Consider a slightly modified form of optimal transport:

\[ d_M^\lambda(\mathbf{r}, \mathbf{c}) = \min_{P\in U(\mathbf{r}, \mathbf{c})}\, \sum_{i,j}P_{ij}M_{ij} - \frac{1}{\lambda}h(P) \]

where the regularization term is the information entropy of \(P\):

\[ h(P) = -\sum_{i,j}P_{ij}\log P_{ij} \]

One can increase the weight of \(h(P)\) (by decreasing \(\lambda\)) to make the distribution more homogeneous. This often works better in practice than the plain Wasserstein distance because it introduces a prior on the distribution of the matrix \(P\): in the absence of a cost, everything should be homogeneous.

Sinkhorn-Knopp algorithm

There exists a very simple and efficient algorithm to obtain the optimal distribution matrix \(P_\lambda^\star\) and the associated \(d_M^\lambda(\mathbf{r}, \mathbf{c})\). This algorithm starts from the observation that the elements of the optimal distribution matrix are of the form:

\[ (P_\lambda^\star)_{ij} = \alpha_i\beta_j e^{-\lambda M_{ij}} \]

where \(\alpha_1,\ldots,\alpha_n\) and \(\beta_1,\ldots,\beta_m\) are constants that have to be determined such that the rows and columns sum to \(\mathbf{r}\) and \(\mathbf{c}\), respectively.
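This form can be obtained, as a sketch, from the Lagrangian of the regularized problem. Introducing dual variables \(u_i\) and \(v_j\) for the row and column constraints,

\[ \mathcal{L}(P, \mathbf{u}, \mathbf{v}) = \sum_{i,j}\Big(P_{ij}M_{ij} + \frac{1}{\lambda}P_{ij}\log P_{ij}\Big) + \sum_i u_i\Big(\sum_j P_{ij} - r_i\Big) + \sum_j v_j\Big(\sum_i P_{ij} - c_j\Big)\,. \]

Setting \(\partial\mathcal{L}/\partial P_{ij} = M_{ij} + \frac{1}{\lambda}(\log P_{ij} + 1) + u_i + v_j = 0\) gives \(P_{ij} = e^{-\frac{1}{2}-\lambda u_i}\, e^{-\frac{1}{2}-\lambda v_j}\, e^{-\lambda M_{ij}}\), which is exactly of the form \(\alpha_i\beta_j e^{-\lambda M_{ij}}\).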

Input: \(M\), \(\mathbf{r}\), \(\mathbf{c}\), \(\lambda\)

Initialize: \((P_\lambda)_{ij} = e^{-\lambda M_{ij}}\)

Repeat:

  • scale the rows such that the row sums match \(\mathbf{r}\)
  • scale the columns such that the column sums match \(\mathbf{c}\)

Until convergence

A code example using NumPy is as follows:

import numpy as np

def compute_optimal_transport(M, r, c, lam, epsilon=1e-8):
    """
    Computes the optimal transport matrix and Sinkhorn distance using the
    Sinkhorn-Knopp algorithm.

    Inputs:
        - M : cost matrix (n x m)
        - r : vector of marginals (n, )
        - c : vector of marginals (m, )
        - lam : strength of the entropic regularization
        - epsilon : convergence parameter

    Outputs:
        - P : optimal transport matrix (n x m)
        - dist : Sinkhorn distance
    """
    n, m = M.shape
    P = np.exp(-lam * M)
    P /= P.sum()             # start from a normalized matrix
    u = np.zeros(n)          # previous row sums; forces at least one iteration
    # alternately rescale rows and columns until the row sums stop changing
    while np.max(np.abs(u - P.sum(1))) > epsilon:
        u = P.sum(1)
        P *= (r / u).reshape((-1, 1))         # scale rows to match r
        P *= (c / P.sum(0)).reshape((1, -1))  # scale columns to match c
    return P, np.sum(P * M)
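As a quick usage sketch on the party example (the preferences matrix below is invented for illustration): a large lam concentrates \(P\) on the low-cost cells, while a small lam, i.e. a heavier entropy term, keeps \(P\) closer to homogeneous.

r = np.array([3, 3, 3, 4, 2, 2, 2, 1], dtype=float)  # portions per person
c = np.array([4, 2, 6, 4, 4], dtype=float)           # amount of each snack
rng = np.random.default_rng(0)
preferences = rng.integers(-2, 3, size=(8, 5)).astype(float)  # hypothetical preferences
M = -preferences  # flip the sign: high preference = low cost

for lam in (0.1, 10.0):
    P, dist = compute_optimal_transport(M, r, c, lam)
    print(f"lambda = {lam}: Sinkhorn distance = {dist:.3f}")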

References:

https://michielstock.github.io/posts/2017/2017-11-5-OptimalTransport/

https://github.com/coderlemon17/LemonScripts