{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "# Introduction to Optimal Transport with Python\n",
    "\n",
    "#### Adapted from *Rémi Flamary, Nicolas Courty*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np # always need it\n",
    "import scipy as sp # often use it\n",
    "import pylab as pl # do the plots\n",
    "import cvxpy as cvx\n",
    "import time"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "## First OT Problem\n",
    "\n",
    "We will solve the Bakery/Cafés problem of transporting croissants from a number of bakeries to cafés in a city (in this case, Manhattan). We ran a quick Google Maps search in Manhattan for bakeries and cafés:\n",
    "\n",
    "![bak.png](https://remi.flamary.com/cours/otml/bak.png)\n",
    "\n",
    "From this search we extracted their positions, and we generated fictional production and sales numbers (both sum to the same total).\n",
    "\n",
    "We have access to the positions of the bakeries ```bakery_pos``` and their respective production ```bakery_prod```, which describe the source distribution. The cafés where the croissants are sold are likewise defined by their positions ```cafe_pos``` and their sales ```cafe_prod```, which describe the target distribution. For fun, we also provide a map ```Imap``` that shows the positions of these shops in the city.\n",
    "\n",
    "\n",
    "Now we load the data\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data=np.load('data/manhattan.npz')\n",
    "\n",
    "bakery_pos=data['bakery_pos']\n",
    "bakery_prod=data['bakery_prod']\n",
    "cafe_pos=data['cafe_pos']\n",
    "cafe_prod=data['cafe_prod']\n",
    "Imap=data['Imap']\n",
    "\n",
    "print('Bakery production: {}'.format(bakery_prod))\n",
    "print('Cafe sale: {}'.format(cafe_prod))\n",
    "print('Total croissants : {}'.format(cafe_prod.sum()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Plotting bakeries in the city\n",
    "\n",
    "Next we plot the positions of the bakeries and cafés on the map. The size of each circle is proportional to the shop's production (bakeries, in red) or sales (cafés, in blue).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "pl.figure(1,(8,7))\n",
    "pl.clf()\n",
    "pl.imshow(Imap,interpolation='bilinear') # plot the map\n",
    "pl.scatter(bakery_pos[:,0],bakery_pos[:,1],s=bakery_prod,c='r', edgecolors='k',label='Bakeries')\n",
    "pl.scatter(cafe_pos[:,0],cafe_pos[:,1],s=cafe_prod,c='b', edgecolors='k',label='Cafés')\n",
    "pl.legend()\n",
    "pl.title('Manhattan Bakeries and Cafés');\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 1. Cost matrix\n",
    "\n",
    "\n",
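    "Compute the cost matrix $C$ between the bakeries and the cafés; it gives the cost of transporting one croissant from each bakery to each café. Here we use the squared Euclidean distance, $C_{i,j}=\\|x_i-y_j\\|_2^2$, where $x_i$ is the position of bakery $i$ and $y_j$ the position of café $j$.\n"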
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
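    "# squared Euclidean distance between every bakery i and every café j\n",
    "C = ((bakery_pos[:, None, :] - cafe_pos[None, :, :]) ** 2).sum(axis=-1)"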
   ]
  },
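  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a cross-check, with $x_i$ a bakery position and $y_j$ a café position, the squared Euclidean cost can also be expanded as $\\|x_i\\|^2+\\|y_j\\|^2-2\\langle x_i,y_j\\rangle$, which avoids building the intermediate array of pairwise differences. The cell below is a minimal sketch that applies this expansion to a few made-up toy points (not the Manhattan data) and compares it with SciPy's ```cdist(x, y, 'sqeuclidean')```. The helper name ```sq_euclidean_cost``` is hypothetical, and the final clipping to zero is an assumption added to absorb floating-point round-off."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.spatial.distance import cdist\n",
    "\n",
    "\n",
    "def sq_euclidean_cost(x, y):\n",
    "    # C[i, j] = ||x_i||^2 + ||y_j||^2 - 2 <x_i, y_j>, via one matrix product\n",
    "    xx = np.sum(x ** 2, axis=1)[:, None]\n",
    "    yy = np.sum(y ** 2, axis=1)[None, :]\n",
    "    C = xx + yy - 2 * x @ y.T\n",
    "    # round-off can leave tiny negative entries; clip them to zero\n",
    "    return np.maximum(C, 0)\n",
    "\n",
    "\n",
    "# toy positions with made-up coordinates: 3 bakeries, 2 cafés\n",
    "x_toy = np.array([[0.0, 0.0], [1.0, 0.0], [2.0, 3.0]])\n",
    "y_toy = np.array([[0.0, 1.0], [1.0, 1.0]])\n",
    "C_toy = sq_euclidean_cost(x_toy, y_toy)\n",
    "print(C_toy)\n",
    "print(np.allclose(C_toy, cdist(x_toy, y_toy, 'sqeuclidean')))"
   ]
  },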
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 2. Solve the OT problem with CVXPY"
   ]
  },
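  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the problem solved in this step can be written as the linear program below, where $a$ is ```bakery_prod```, $b$ is ```cafe_prod```, $n$ and $m$ are the numbers of bakeries and cafés, and $\\mathbf{1}_k$ is the all-ones vector of length $k$:\n",
    "\n",
    "$$\\min_{T\\in\\mathbb{R}_+^{n\\times m}}\\ \\sum_{i,j}T_{i,j}C_{i,j}\\quad\\text{s.t.}\\quad T\\mathbf{1}_m=a,\\quad T^\\top\\mathbf{1}_n=b$$\n",
    "\n",
    "The row sums of $T$ are what each bakery ships and the column sums are what each café receives, so both constraints, together with $T\\geq 0$, translate directly into constraints on a nonnegative ```cvxpy``` variable of shape ```C.shape```. The problem is feasible only because the total production equals the total sales."
   ]
  },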
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
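    "T_var = cvx.Variable(C.shape, nonneg=True)  # transport plan, T[i, j] >= 0\n",
    "constraints = [cvx.sum(T_var, axis=1) == bakery_prod,  # each bakery ships its whole production\n",
    "               cvx.sum(T_var, axis=0) == cafe_prod]  # each café receives exactly its sales\n",
    "prob = cvx.Problem(cvx.Minimize(cvx.sum(cvx.multiply(C, T_var))), constraints)\n",
    "prob.solve()\n",
    "T = T_var.value  # the solution is T"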
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Transportation plan visualization\n",
    "\n",
    "A good visualization of the OT matrix in the 2D plane is to represent the mass transported between a bakery and a café by a line. This is easily done with a double ```for``` loop.\n",
    "\n",
    "To make the plot more interpretable, one can scale each line by the amount it carries, for example by setting the ```alpha``` parameter of ```plot``` to ```alpha=T[i,j]/T.max()```. The cell below scales the line width ```lw``` in the same way."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "labels = [str(i) for i in range(max(len(bakery_prod), len(cafe_prod)))]  # index labels for bakeries and cafés\n",
    "\n",
    "# Plot the matrix and the map\n",
    "f = pl.figure(3, (14, 7))\n",
    "pl.clf()\n",
    "pl.subplot(121)\n",
    "pl.imshow(Imap, interpolation=\"bilinear\")  # plot the map\n",
    "for i in range(len(bakery_pos)):\n",
    "    for j in range(len(cafe_pos)):\n",
    "        pl.plot(\n",
    "            [bakery_pos[i, 0], cafe_pos[j, 0]],\n",
    "            [bakery_pos[i, 1], cafe_pos[j, 1]],\n",
    "            \"-k\",\n",
    "            lw=3.0 * T[i, j] / T.max(),\n",
    "        )\n",
    "for i in range(len(cafe_pos)):\n",
    "    pl.text(\n",
    "        cafe_pos[i, 0],\n",
    "        cafe_pos[i, 1],\n",
    "        labels[i],\n",
    "        color=\"b\",\n",
    "        fontsize=14,\n",
    "        fontweight=\"bold\",\n",
    "        ha=\"center\",\n",
    "        va=\"center\",\n",
    "    )\n",
    "for i in range(len(bakery_pos)):\n",
    "    pl.text(\n",
    "        bakery_pos[i, 0],\n",
    "        bakery_pos[i, 1],\n",
    "        labels[i],\n",
    "        color=\"r\",\n",
    "        fontsize=14,\n",
    "        fontweight=\"bold\",\n",
    "        ha=\"center\",\n",
    "        va=\"center\",\n",
    "    )\n",
    "pl.title(\"Manhattan Bakeries and Cafés\")\n",
    "\n",
    "ax = pl.subplot(122)\n",
    "im = pl.imshow(T)\n",
    "for i in range(len(bakery_prod)):\n",
    "    for j in range(len(cafe_prod)):\n",
    "        text = ax.text(\n",
    "            j, i, \"{0:g}\".format(T[i, j]), ha=\"center\", va=\"center\", color=\"w\"\n",
    "        )\n",
    "pl.title(\"Transport matrix\")\n",
    "\n",
    "pl.xlabel(\"Cafés\")\n",
    "pl.ylabel(\"Bakeries\")\n",
    "pl.tight_layout()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Compute the OT loss\n",
    "\n",
    "The resulting OT loss, also called the Wasserstein loss or EMD (Earth Mover's Distance), is:\n",
    "\n",
    "$W=\\sum_{i,j}T_{i,j}C_{i,j}$\n",
    "\n",
    "where $T$ is the optimal transport matrix and $C$ is the cost matrix.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "W = np.sum(T * C)\n",
    "print(\"Wasserstein loss (EMD) = {0:.2f}\".format(W))"
   ]
  },
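  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The same linear program can also be passed to a generic LP solver, which gives an independent check of the ```cvxpy``` result. The cell below is a minimal sketch using ```scipy.optimize.linprog``` with the HiGHS method (available in recent SciPy versions): the plan $T$ is flattened row by row, so the row-sum and column-sum constraints become Kronecker-product matrices. The helper name ```solve_ot_lp``` and the two-bakery, two-café toy numbers are illustrative assumptions. On the notebook's data, ```solve_ot_lp(bakery_prod, cafe_prod, C)``` should return the same optimal loss $W$ as the ```cvxpy``` solve up to solver tolerance, even though the optimal plan itself need not be unique."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.optimize import linprog\n",
    "\n",
    "\n",
    "def solve_ot_lp(a, b, C):\n",
    "    # solve min <T, C> s.t. T 1 = a, T^T 1 = b, T >= 0 as an LP over T flattened row by row\n",
    "    n, m = C.shape\n",
    "    # row sums: sum_j T[i, j] = a[i]\n",
    "    A_rows = np.kron(np.eye(n), np.ones((1, m)))\n",
    "    # column sums: sum_i T[i, j] = b[j]\n",
    "    A_cols = np.kron(np.ones((1, n)), np.eye(m))\n",
    "    A_eq = np.vstack([A_rows, A_cols])\n",
    "    b_eq = np.concatenate([a, b])\n",
    "    res = linprog(C.ravel(), A_eq=A_eq, b_eq=b_eq, bounds=(0, None), method='highs')\n",
    "    return res.x.reshape(n, m), res.fun\n",
    "\n",
    "\n",
    "# toy example with made-up numbers: two bakeries and two cafés, one unit each\n",
    "a_toy = np.array([1.0, 1.0])\n",
    "b_toy = np.array([1.0, 1.0])\n",
    "C_lp = np.array([[1.0, 2.0], [2.0, 1.0]])\n",
    "T_lp, W_lp = solve_ot_lp(a_toy, b_toy, C_lp)\n",
    "print(T_lp)\n",
    "print('W =', W_lp)"
   ]
  },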
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 3. Regularized OT with Sinkhorn\n",
    "\n",
    "The Sinkhorn algorithm is very simple to code. You can implement it directly using the following pseudo-code:\n",
    "\n",
    "![sinkhorn.png](http://remi.flamary.com/cours/otml/sink.png)\n",
    "\n",
    "Be careful with numerical problems: for a small ```reg```, the entries of ```np.exp(-C / reg)``` can underflow to zero. A good preprocessing step for Sinkhorn is to divide the cost matrix ```C``` by its maximum value."
   ]
  },
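  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The Sinkhorn iterations solve an entropy-regularized version of the OT problem. One common way to write it, with $a$ = ```bakery_prod```, $b$ = ```cafe_prod```, $\\lambda$ = ```reg``` and the normalized cost $\\tilde C = C/\\max_{i,j}C_{i,j}$ used in the cell below, is:\n",
    "\n",
    "$$\\min_{T\\geq 0,\\ T\\mathbf{1}=a,\\ T^\\top\\mathbf{1}=b}\\ \\sum_{i,j}T_{i,j}\\tilde C_{i,j}+\\lambda\\sum_{i,j}T_{i,j}\\log T_{i,j}$$\n",
    "\n",
    "Its solution has the form $T=\\mathrm{diag}(u)\\,K\\,\\mathrm{diag}(v)$ with the elementwise kernel $K=\\exp(-\\tilde C/\\lambda)$. Sinkhorn alternates $v\\leftarrow b\\oslash K^\\top u$ and $u\\leftarrow a\\oslash Kv$, where $\\oslash$ is elementwise division, so that the column and row sums of $T$ match $b$ and $a$. A larger $\\lambda$ gives a smoother, more spread-out plan, while as $\\lambda\\to 0$ the plan approaches an unregularized OT solution."
   ]
  },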
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute Sinkhorn transport matrix from algorithm\n",
    "reg = 0.1  # entropic regularization strength\n",
    "K = np.exp(-C / C.max() / reg)  # Gibbs kernel on the normalized cost\n",
    "nit = 100  # number of Sinkhorn iterations\n",
    "u = np.ones((len(bakery_prod),))\n",
    "for _ in range(nit):\n",
    "    v = cafe_prod / np.dot(K.T, u)  # match the café (column) marginals\n",
    "    u = bakery_prod / np.dot(K, v)  # match the bakery (row) marginals\n",
    "\n",
    "ot_sinkhorn = u[:, None] * K * v[None, :]  # equivalent to np.dot(np.diag(u), np.dot(K, np.diag(v)))"
   ]
  },
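  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Dividing ```C``` by its maximum helps, but for a small ```reg``` whole rows or columns of ```K``` can still underflow to zero, and the updates then divide by zero. A common workaround, which the exercise above does not use, is to run the same updates on the log-scaled potentials $f=\\lambda\\log u$ and $g=\\lambda\\log v$ with ```scipy.special.logsumexp```. The cell below is a minimal sketch of this log-domain variant with a stopping test on the column-marginal error; the function name ```sinkhorn_log```, its ```max_iter```/```tol``` parameters and the toy inputs are illustrative assumptions. On the notebook's data, ```sinkhorn_log(bakery_prod, cafe_prod, C / C.max(), reg)``` should approach the same plan as ```ot_sinkhorn``` once both iterations have converged."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.special import logsumexp\n",
    "\n",
    "\n",
    "def sinkhorn_log(a, b, C, reg, max_iter=1000, tol=1e-9):\n",
    "    # log-domain Sinkhorn with potentials f = reg * log(u), g = reg * log(v)\n",
    "    f = np.zeros(len(a))  # u = 1 initially, as in the cell above\n",
    "    log_a, log_b = np.log(a), np.log(b)\n",
    "    for _ in range(max_iter):\n",
    "        # v <- b / (K^T u), computed in log space\n",
    "        g = reg * (log_b - logsumexp((f[:, None] - C) / reg, axis=0))\n",
    "        # u <- a / (K v), computed in log space\n",
    "        f = reg * (log_a - logsumexp((g[None, :] - C) / reg, axis=1))\n",
    "        # T = diag(u) K diag(v), with K = exp(-C / reg)\n",
    "        T = np.exp((f[:, None] + g[None, :] - C) / reg)\n",
    "        # rows match a right after the u-update, so only the columns need checking\n",
    "        if np.abs(T.sum(axis=0) - b).max() < tol:\n",
    "            break\n",
    "    return T\n",
    "\n",
    "\n",
    "# toy example with made-up numbers: 2 bakeries, 2 cafés\n",
    "a_demo = np.array([0.5, 0.5])\n",
    "b_demo = np.array([0.5, 0.5])\n",
    "C_demo = np.array([[0.0, 1.0], [1.0, 0.0]])\n",
    "T_demo = sinkhorn_log(a_demo, b_demo, C_demo, reg=0.5)\n",
    "print(T_demo)"
   ]
  },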
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the matrix and the map\n",
    "f = pl.figure(3, (14, 7))\n",
    "pl.clf()\n",
    "pl.subplot(121)\n",
    "pl.imshow(Imap, interpolation=\"bilinear\")  # plot the map\n",
    "for i in range(len(bakery_pos)):\n",
    "    for j in range(len(cafe_pos)):\n",
    "        pl.plot(\n",
    "            [bakery_pos[i, 0], cafe_pos[j, 0]],\n",
    "            [bakery_pos[i, 1], cafe_pos[j, 1]],\n",
    "            \"-k\",\n",
    "            lw=3.0 * ot_sinkhorn[i, j] / ot_sinkhorn.max(),\n",
    "        )\n",
    "for i in range(len(cafe_pos)):\n",
    "    pl.text(\n",
    "        cafe_pos[i, 0],\n",
    "        cafe_pos[i, 1],\n",
    "        labels[i],\n",
    "        color=\"b\",\n",
    "        fontsize=14,\n",
    "        fontweight=\"bold\",\n",
    "        ha=\"center\",\n",
    "        va=\"center\",\n",
    "    )\n",
    "for i in range(len(bakery_pos)):\n",
    "    pl.text(\n",
    "        bakery_pos[i, 0],\n",
    "        bakery_pos[i, 1],\n",
    "        labels[i],\n",
    "        color=\"r\",\n",
    "        fontsize=14,\n",
    "        fontweight=\"bold\",\n",
    "        ha=\"center\",\n",
    "        va=\"center\",\n",
    "    )\n",
    "pl.title(\"Manhattan Bakeries and Cafés\")\n",
    "\n",
    "ax = pl.subplot(122)\n",
    "im = pl.imshow(ot_sinkhorn)\n",
    "for i in range(len(bakery_prod)):\n",
    "    for j in range(len(cafe_prod)):\n",
    "        text = ax.text(\n",
    "            j, i, \"{0:g}\".format(ot_sinkhorn[i, j]), ha=\"center\", va=\"center\", color=\"w\"\n",
    "        )\n",
    "pl.title(\"Transport matrix\")\n",
    "\n",
    "pl.xlabel(\"Cafés\")\n",
    "pl.ylabel(\"Bakeries\")\n",
    "pl.tight_layout()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 4. What is the POT (Python Optimal Transport) library? https://pythonot.github.io/"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "See the introductory example: https://pythonot.github.io/auto_examples/plot_Intro_OT.html#sphx-glr-auto-examples-plot-intro-ot-py"
   ]
  }
 ],
 "metadata": {
  "celltoolbar": "Slideshow",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  },
  "livereveal": {
   "header_not": "<h1>Introduction à Python</h1>",
   "scroll": true,
   "transition": "none"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
