{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "# Optimal Transport between empirical distributions\n",
    "\n",
    "Illustration of optimal transport between 2D distributions that are weighted\n",
    "sums of Diracs. The OT matrix is computed with the exact solver (EMD) and with\n",
    "entropic regularization (Sinkhorn), and plotted together with the samples.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Author: Remi Flamary <remi.flamary@unice.fr>\n",
    "#         Kilian Fatras <kilian.fatras@irisa.fr>\n",
    "#\n",
    "# License: MIT License\n",
    "\n",
    "# sphinx_gallery_thumbnail_number = 4\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pylab as pl\n",
    "import ot\n",
    "import ot.plot"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate data\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n = 50  # nb samples\n",
    "\n",
    "# Fixed seed so every Restart & Run All draws the same point clouds.\n",
    "rng = np.random.RandomState(42)\n",
    "\n",
    "mu_s = np.array([0, 0])\n",
    "cov_s = np.array([[1, 0], [0, 1]])\n",
    "\n",
    "mu_t = np.array([4, 4])\n",
    "cov_t = np.array([[1, -0.8], [-0.8, 1]])\n",
    "\n",
    "# One shared generator: the second draw continues the random stream instead of\n",
    "# replaying the first, so source and target noise are not identical.\n",
    "xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s, random_state=rng)\n",
    "xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t, random_state=rng)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot data\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scatter both point clouds; xs.T / xt.T unpack into the (x, y) coordinate columns.\n",
    "pl.figure(1)\n",
    "pl.plot(*xs.T, \"+b\", label=\"Source samples\")\n",
    "pl.plot(*xt.T, \"xr\", label=\"Target samples\")\n",
    "pl.legend(loc=0)\n",
    "pl.title(\"Source and target distributions\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Uniform weights on the samples: each histogram sums to 1.\n",
    "a = np.ones((n,)) / n  # uniform distribution on source samples\n",
    "b = np.ones((n,)) / n  # uniform distribution on target samples\n",
    "\n",
    "# loss matrix: C[i, j] = squared Euclidean distance between xs[i] and xt[j]\n",
    "# (ot.dist defaults to metric=\"sqeuclidean\")\n",
    "C = ot.dist(xs, xt)\n",
    "\n",
    "pl.figure(2)\n",
    "pl.imshow(C, interpolation=\"nearest\", cmap=\"gray_r\")\n",
    "pl.title(\"Cost matrix C\");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compute EMD\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# T: exact OT plan between the histograms a and b for the cost C (EMD solver)\n",
    "T = ot.emd(a, b, C)\n",
    "\n",
    "pl.figure(3)\n",
    "pl.imshow(T, interpolation=\"nearest\", cmap=\"gray_r\")\n",
    "pl.title(\"OT matrix T\")\n",
    "\n",
    "pl.figure(4)\n",
    "# overlay the plan T as segments between coupled source/target samples\n",
    "ot.plot.plot2D_samples_mat(xs, xt, T, color=[0.5, 0.5, 1])\n",
    "pl.plot(xs[:, 0], xs[:, 1], \"+b\", label=\"Source samples\")\n",
    "pl.plot(xt[:, 0], xt[:, 1], \"xr\", label=\"Target samples\")\n",
    "pl.legend(loc=0)\n",
    "pl.title(\"OT matrix with samples\");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compute Sinkhorn\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# entropic regularization strength (larger values give a smoother, more spread-out plan)\n",
    "lambd = 1e-1\n",
    "\n",
    "# Ts: entropy-regularized OT plan between a and b for the cost C, via Sinkhorn iterations\n",
    "Ts = ot.sinkhorn(a, b, C, lambd)\n",
    "\n",
    "pl.figure(5)\n",
    "pl.imshow(Ts, interpolation=\"nearest\", cmap=\"gray_r\")\n",
    "pl.title(\"OT matrix sinkhorn\")\n",
    "\n",
    "pl.figure(6)\n",
    "# overlay the regularized plan Ts as segments between coupled source/target samples\n",
    "ot.plot.plot2D_samples_mat(xs, xt, Ts, color=[0.5, 0.5, 1])\n",
    "pl.plot(xs[:, 0], xs[:, 1], \"+b\", label=\"Source samples\")\n",
    "pl.plot(xt[:, 0], xt[:, 1], \"xr\", label=\"Target samples\")\n",
    "pl.legend(loc=0)\n",
    "pl.title(\"OT matrix Sinkhorn with samples\")\n",
    "\n",
    "pl.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
