{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Logistic Regression\n", "\n", "The following script generate a fake classification dataset." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXYAAAD8CAYAAABjAo9vAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJztnXuQXUd9579972jGUMHlypgtb2zLCuzyUPDGNhNlVRBrsmYVbGxwoRQVliBig+XBhqAA0UYxiUahIhckgAC7zMiv8tSSsFSZtTcEl7GNFbuiYWPJj1A8QhkqeLEhOEO5WIpY0mh6/2g1p0/ffp3HPefeM99P1ak7994+3X2OdL/9O7/+9a+FlBKEEEK6Q6/tDhBCCKkXCjshhHQMCjshhHQMCjshhHQMCjshhHQMCjshhHQMCjshhHQMCjshhHQMCjshhHSMiTYaPf300+WGDRvaaJoQQsaWI0eO/KuU8sWxcq0I+4YNG3D48OE2miaEkLFFCPG9lHJ0xRBCSMegsBNCSMegsBNCSMegsBNCSMegsBNCSMegsBNC6mN+vu0eEFDYCSF1sndv2z0goLATQkjnoLATQqoxPw8IoQ4g+5tumdYQbWxmPTMzI7nylJAOIgTQgqasFYQQR6SUM7FytNgJIaRjUNgJIfWxZ0/bPSCgsBNC6oR+9ZGAwk5Il6CwElDYCekWVeLIOSh0Bgo7IURhDgoU+bGGwk7IuFM2jjz0PVeQjjUUdkLGnfl5FTuu48f13zFh37vXPyiktls3fFKoBS5QIqRLFFkgZJf1CfqePW7BHcZiJC5wCsIFSoSsRWJx5DG3Tczyp0U9FlDYCekSKX51n3inLC7au7f+vDDMNVM7tblihBB9AIcBPC2lvDRUlq4YMlLMz69NEQm5PXz3xDyHrpjGacMV8z4A36yxPkKaoe4IkHEZJEIWuu1+KTvBmgIt9tqpxWIXQpwF4A4Afw7g/bTYyVhRxEpMse5T6hvXpwRa7K3StMW+H8AuAKs11UfIcClrJdZl3afWU1T8x3GwqBveg+rCLoS4FMCPpJRHIuV2CCEOCyEOP/vss1WbJaQaZWO/Q/UNw51QdCDR5Yclblu2DNdtUkd2SC6uqu6KEUJcD+DtAFYAnALgVABfkFL+ru8cumLISBF7/J+fd4tF0fjuovWk9M1XvgmXxqi6TUa1XzXQmCtGSrlbSnmWlHIDgN8B8JWQqBMycqTEftdh3afWU9T6d5XXnxdlXN0YnIDNI6Ws7QAwC+CLsXKvfvWrJSFjCRAvs2dPPfUUKafbzYaL7Ejpjy5T5PrselPaaYLUax5DAByWKVqcUqjug8JOxpa6BMNVjymYMYH29UMLc5EBIeU8sz1fmaJtDgt9vzoIhZ2QUSF1MHCJUVERLWJ5m+fEBhKzvlEXdn09HSRV2JlSgJBh02SUhvYpp0SXaL+03b89e9y+f5cPe3Z2dHzb9vWsYT87hZ2QNolN+pkCXWSCMEXMXJO55rl2eyb6vNnZeDtNUXcI6xhDYSdkGKSKcEyM6halIvX5+pZSZg2K6ShBYSdkGAxD8OxzbQs7BZ9bKDVvjF12VAV8y5bi54zqtZSAwk7IqOATV59LpC6/tl445atv715/38yBouqq0TqF9e/+rvg5XVqxmjLDWvfBqBiypqgzptqOWEmNOw9FvcRCGfVnrjarRJ/Y7dcZyVKmrjGIpAH
DHQkZc1wDgh2jXTZe3feZKeK+wcD1eZlFQXbsfNWFRWUWaBU5ZwQWPVHYCRkVyq7QdInwnj3ZUaSuUJ3mZ77BJCToqdh1h+ote31m3XWfMwIWPYWdkFHBFoSYQKS4Joos/Xd9lmqpxgS9iNjFBgjXQFFWTCnsFHZChkqKWBUV3BSRSRWiIgNNFReM7z6E6rXPSW2rjKXvGwjrcDvVBIWdkDZJsUxNgXD5uV2iVkRkUoR9y5biA4A+J9UvHRJtU7ztz1KEvklosVPYCfk5KRZ7TPhDFrtZpoz4l7GCy7pi7Prs9mP3isJOYSdkJPCJU8yqTxU8n8imiLZrgPDhyi5ZxFXiG9BC58cs+CZhVAyFnZCfE4uKsScN7ff6M18dIWF3ianPlRKzSF19KOL712WLunLMeuuwmkdAoMuSKuyVt8YrA7fGI2uO+fn8YaJXfO7Zk7b6cc8e4ODB8OpKvQp0714lmz7MbeRiW8rZ37vKp2xLV6RN+5w6tr0b463zGtsajxCSgBZsc7NpO02AvXTftGfN94ASdfs7/aoHCF/6Wl+CsiJlzcEpVKZqbpy6tvyLtTMKddRJillf90FXDBk6o/C4HXKX2N/F/MghV03ITVHWFVN0oZLP5x6bT9iyZbAdF1VdMMOOtGloYhX0sZM1zQhEMAQnRmNi6PLLx8TRJ8ahz+y/7T5VCdtMaT+V1HPK7DfbQWGnK4aQujEfy7X0AZmbRb9PrUPj8r9rf7t5js81EtsUw+Xjj+Vk1+/NjUHKth/rWwgzE2UKRTYtGWYdwyJF/es+aLGToTAKYXGpFq7PBaLrCOE6J9SXUL9MV0gVd1DZe1DXv03IHeXqn+vcOtofMqArhqxp2nTF2CJTt9AVKedyu7gGkpS+FAlTtNsc5qBbtG4KO4WdjDhVltPXScgiLmO1hzDF0u6DXa6oGKc+RbgGr1h9ZQS9rpS7rr6ntJFKQ0+FFHbSXUIrMl1lhklowjK0EMn3dwo+yzsm5CERLuKGMeuKuUDqTkMQK5dSvs2nuYpQ2El3cbkX2sLVvq9PtlAWsWZTrFOf6yXWrxT/u11vHf2OPXGE+uwr58utMwrzLzVAYSfdperjfd19sUnx7drClmppFgk/LONycD1JhIS5yL9DbKAxxbnov699b0P3s21joAIUdtItQoI2Kn2pMrAUcSHYr1UX/9htmKJtfp4ymKT0P/W92Z+i10Bhp7CTMSPkbhgWPjeBKTxVhN2VOVH/nWKR2+VS2zLfx54GbDeS/tt81fX43DJlFjyl/hunDrhj5n4xobCT7mKLSCqpj/KhNl0+bZe4VcFsK1XYzc9cfXXVr8ukuFxC5VxuJftepPjPXfcvdcl/1ZDH2OcjAoWddJeyP74qj+emeKcIbRV8A4Ut3mXdIy4XS+ja7P64omL052Y9vjkEV6SNq0wRoTbPT7Hyfd/X9W84JCjshNgUFfZYREdRQQ2RGvVi/p0qxKH6Y9eQWi5m0dt9133Sr6Et+uzyPszrLurvj30+IlDYCZGynDjbYhqbuPVZ2HY/UvC1s2WL392RIsA+F1LdR2r9Kf8+Zn9997eIZR+aZE6to2Uo7IRoUoTXFhv9GvrB2y6F1PpjfXUJYGyAcVnEKVZpymBVVcRDR4qo+hZJma+h643dg5TPRwQKOyGaosLrEyGfHz0lKibUdkoop0twze98TxmuCJUiaQ5cYm9eq6/8Oee4+2/3zyf2qffG14aNfQ8o7BR2MoakPKL7/g4JatEonFgfNEXEtqy/3O5bUdeEqx7zvoTK209Bdh32QGYTG+jMNnzn29fvYgTdLyYUdkI0KdZZzDKM/eDLTtalxKzH3BUhofSJZqqVbtafGu8eumdm322rP7SgqszgZZ/fARoTdgBnA3gQwDcAfB3A+2LnUNjJ0EnxvaYKfgox4Ujtgy8WPcWSLjKJalq8qWVdfSwqtL4BwBUVkxLDH7rvZUImR5wmhf3fA7jg5N8vAvBtABtD51DYydD
xiXPRsMIy7blIWSykMUXOFD9TDO3rSxHq0KF94kXFOqVcLBQy5R7a150i7K7zx5zWXDEA7gbwX0NlKOxk6KT8kH1lUiy6lHDJkB+3iLialrNtRRexwPXhs/zNz133KeZaSRHu1CN1Z6e6nqhsqrrehnRuK8IOYAOApwCcGipHYSdDoeijd1krzjcBWKZ+s4zLOq966Ph3s52irhTbOnZZ4/p718CQ6iJy3QuzjtDnMYqKaezfrsoTQIVzGxd2AL8A4AiAN3u+3wHgMIDD69evL31hhCSR8uMpazmFrNqiffCJmS2mZpsuH7OvnK/umE/exCXwqWIdK5faB9d1uN6nUnayu2q7Fc9tVNgBrANwL4D3p5SnxU6GTpUfno8Ul0DRpwafW+Scc7LrKHvYlrVvoU/M4o9Z8/ag4RpYXN/5BiVfWy7Xi+vfOWXAtu9B6F6U/be1+1T23Fy3m5s8FQAWAexPPYfCToZO3ZEPZX6YRQcXl1WdKuA+odWfl12UFLv+IoNLEZF0DQ72oBA71z4/dK9D38fOL0qFc5sU9tcCkAD+EcDjJ49LQudQ2MlY4xM/Vzmb1IHAPjdVgG2xt61hu27fxGxI7IoIe6rF77uuVAEOuZ3Ma025h762inw/pHO5QImQYVDEcvdFkbi+M2O5XXX7hDdmxcYs6RSx1m4hexAo2p5PQGMDTtX+u0it18daiopJPSjsZKxJ/fG7cImY/d4n3Pb3Zaxg05KODSS6XbNu289t9jU0OJmx+THLuYhl7fq3iN0fV3mbMv+2DUBhJ2SY+CzQUNkiQlPEIteU8aPHLF3XexPfYirXNYXi4FMsZ7vtlHvi+zcq47svS43zPRR2QuogNPlWNIa6any6zyI363a1Y57jOt/VTmglqu/67O9iopxSxvwsdX7CN2D6CLnM6qDGuijshNSB70cZ+rH6QvJCFmHMYjfLxaxssw5zUEm16FPLpcSpx643JMKpg2fo6anKWoW6oLAT0hIpYXGpj+o+y7VKyKCvH0VCIn3n6z66JnRTrtd1nfZ9cPXfd36oTJX+FaEO98sQ+kZhJ6QIRQU8JDohq9es29d+GfF39dNlFdttxK7HHlhi122/Fu1v2fUCdVrYdUOLnZCW8IlgSPBMYm6WVBHWf1cR9RRhNKNb7EgXu1zo/qSWdV2flHG3j/2kYf49LGu9bijshDRIFcs8JIQua9h+7xLaUJ9cFr6vTpfl7OtzanRPTNBj9zIk3vY1pQx+JqMq6BpGxRDSEj5XQ5EfZREfeqzuVKF09RNI2yTarqeuiUaXAMcGIfNv++kjVG/o8w5CYScklTof621BcglYzAJ1Wd6uASfVNeESySrXnDpZ7PssdfALHUVDGjsChZ2QIsTENnZuSFR99cYs0FSr23UN5sBSRDBTMK/N1beYW8ncRMN3X1Ija+oYjMcICjshZUgRN9+Ennm+byKxjF/fZ3W7+p5i6doDTtHBo4zFXESUfWLuui+p7XcECjshZTCtSR8hsYlZs756YtZtilVd1Eo3++u6FpOUp5IQRYQ9FAnjqpPCTmEnJEhRkfJZkSn1+b6z6w8l1irqakm1lEN9tgcYlyDbxBY1lcF8AqnCGLlvKOyEmKT+eH1CkyKgvoVARftkC3uRtooKve0TL3PtVVwioftdpC9VxHmMLH4KOyEmZd0MsbpiolBGcGKLdnz4BoSQMKe6Ulz1htp34Yuj97WT2peqVKmjYWufwk6ISeqPt6orpm50/VqA7TmAIql6TYEvuwjLFvYikTtFLGuXq8d8X9Vqr8vyb9jap7CT7hP7EZb58ab8UEMTeikUOUf3J8XVY/fdzJWu27UjbEL3xh5UYk8RsXsXs7BD7YTKV6VKHRR2CjupmaIuhBSaeLQu0u+U5fQx69xu167LFPBQP311hoS9rGWc+lTUhrC3GENPYSfdZxjCXgdFRCtWj0tAUjfBMF03qYuoYuVih8+PXubJKeWcOsR0jCZeKeykm5S1lpq
c5KpquRZxudh/lxFg31OB670dXhhzldjnx9wxrnMbFs9CNGSpZ81R2EnXafoHXzVkssr3ITF3Heakpn2e+VnMQk9ZJeq7HrNcKBY/xigLe11+/kQo7KT7NP2DD7UXWzmaWo/v+xRh1OfFJiFD0S72Nfn6Zca+h0IWXXWmWuxlnszKMibuGAo76T5NrxhM/QH7rGuNq9+hsMFU14ct3PZkZ4pPO0Tsuuw2fYJcVAibEM6ibbQ0gUphJ6QOqoZMFh0MXBatyxJ3YeaVMc+tIsJ2ebtPZepNcb+YjKKwS1ltxW1JKOyE1E3qD7houl2zbpcQ+9wqMRdI0XNilFmhW4foFVmxWrTeKla3+W/WEBR2QurG5WMuco5NyC9fdrFOkXOqXFuKi2mYE4t111umvhQXV81Q2EmnOHRIyn371GuVMpUo4yeuYgmXsbztNlOFvei1ucr4xL6o6yWFpoXdNW/RoG9dQ2EnneHQISlf8AIp+3316hLulDK1kiosRSz7kBtG/13UmtbWfxGKDgSx8+oS4WGKaplFZQ26YLImKeykI+zbpwQbUK/79pUrU5lhCEtRaz21Tk3MXVDVEi0TBVOHEDctqhR2Cjupl7G22IvW6YodN9MClKnTfPV9H/usSFtShuPu67h3TYhqbNBqOtxWUthJxxgJH7tJyuRhHXUWsdTNfqRa36ZFX1XE7AlYX311iPKormEYMhR2QoZJER9zlTqrCGEo6iYk4j5ffwzfJKkp6C1NOlZmzIRdqLLNMjMzIw8fPtx4uyRjaQk4eBCYnQU2b267N4pR7FMhhFByVQfz88DevYOf79mjvnOVd32u++Trm+tzIdSr+XnKtaW2Ued9agrf/W0YIcQRKeVMtGCK+td90GJvl6b80UVcI437yOuiCUs0xVpMdYOUOS/12nz9DOWcIYVAk64YAK8H8E8AngTwR7HyFPZ2aSKCxBbqhYWwyDcS1TJshuF399WbWiY1KiZWX0i0iw5s4+J+GUEaE3YAfQDfAfASAJMAngCwMXQOhb1diljHZSckTaHu9aRct27EolqGQd0hcbGJy2Ek3PKdV2VwIbXRpLBvBnCv8X43gN2hcyjs7ZMaZVJWbM1zJyaUuMes8UajWoZB3ROqRc5NdYOk4koPULYPdbRPpJTNCvtvA7jFeP92ADeEzqGwF6MtwavqHtH9XljogDVehLr87nUIe5M0Pa/go8ODwsgJO4AdAA4DOLx+/foGbkE3GLaLIjRo1Nn22FvjZRlmDLp9Xpdo6klnzKArpiMMa1Lx0CEp5+aknJqK+77XpCDXBQUqnTaedMaMVGHvlYmltHgEwH8UQvyyEGISwO8A+N811EugYronJ4F+X73Ozlavc2kJuOgiYGEBOHoUOHECOHZMxZDbbN4M7N6tXpeWgOuvBw4cUK9LS9X70nn27Ble3b646hGIty7F/Hwm50D2d8r1zM+r+Hgdg6//Htd7UZUU9Y8dAC4B8G2o6JjrYuVpsRejbqvZfAoApBQiLTrmBS/IJkF7vTXiL2+TMhkHQ5+PE+P+pDMk1xgatNghpfySlPJlUsqXSin/vI46SYZpNdeB/RRw9dXAAw+469dW+uKisupXV9Xnq6t+K5/UxKham030a5hPOk3gWjXcILUIOxkdtBCH3CSbNysh//CHlTDfdJNf1C+6CPiTPwFuuw2YmAB6J//H9Hr1uYZIAXwuh9nZ5lwRTYhWlX6P+6BQA8wV0yG0EB87pkTXZ4Wncv31StRPnFDW/VVXAevXA9PTwPJyvTldxj5PTBsUyf/SRLtrnaL5fUqQmitmopbWyEhw8KASdXMytIxIapGdnlYDhB4otm8fjujWPSCRIWCLln4yqFG0xh4zUVjLgx+FvUNo37kWyDJuEltk9+8HHnusfJ9SLPG6BqQ1h8/lMAxXxAiJFolDH3tLpPjCi2L6zm2rN7U9W2Qfewy44w7g5puV4Bfprx4kPvQh4MILVZiki2GEdK4JuhbuaDP
O19Gyn58Wews07Xoo0p5t9QPFrWltpT/1lIqTX11Vx3veo763/fN6QKKPfUxoSrT27h1fcW+53xT2FhiW68En4EXas0UWUBZ7qnvH7MPEROaKBYCVFSXuq6v5/nHidMwYV7FdQ9AV0wLDcj0cPJitJD16NIsxd7UXcs2YcfMh946vD3oQWVkBLrsMWLdOhUf2++pzc4AxQyqLunpIB7AHCa4grQWGO7bEMKzUAwfUYiPNwgJw7rlZhIt2gXzta8C112ZhjDfeCOzYUU8fXE8NQNaHnTvz3x08mA+p/PCH1aBC1gihiVhO0g7AcMcRR1vDdbK8rCzj1VX1+thjg0IKKHfIyor6W7tHzj23nv64/OXaCj/3XLcvvWokT9useVdSXfuBjsi+op0gJe9A3QdzxcRZWJBy61b1moqdZndubjAz5L59Wb4XffR6w9v8IiX17zhnkOzEzk9VqSs1sV1P11IR1wASc8XQYh9BTJfKl7+sXlNcJakTn1NTwPPPZ5vXT025LeUq0TtmZExs4la/13MCoTZGzTpuNQZ/XC1cX0y8OdOuy5FypKh/3Qct9jBbt+aNmK1by9flsoYXFtR2dUKoV/OpwCxfJhe8zvM+OanOmZxUOd/1vqeuJ5BUq7fqVn3DeCpo1WJvM4thnbnT66hnjQBa7KPH0pLKkgiEl+dv25ZZ6vp9WVy+/OXl/K9oeTnrn73qtIj/W5+vnwY0l10GfPGLyqrduXPQn59q9Za1joe5bmDNxuDXtRLVTEnAydLaYLhjQxw4APzGbwCf+Yw6fvM3/aF9O3aoiJatW9VrXRErGl/44/x8fuON5eVyoY7mk/XkJHDGGeozX6pf3Z9eTx3T0+pzOySzbJioa0Cok7rTKgfpWjjguPZ71Ekx6+s+1por5tAh5YYw7WQhwq6Noq6DKuXr2kRjYUFdZ6+n3C9zc/n6Q+4K7R7Sbe/alX+vzynjUunsBGfMFdOUO6Ouduh+iYKm9jwtc6w1YXdFokxN1edLripcZv96PeXTL1qHOThMTChhNgU4JsimP7/Xy9+vUNROkf6Na+SNl5iwj8JOQqRWUoWdPvYGmJ1VkSdHj6qn5ssuA3btKp7t0BcRUrS8/fn0dH5npG3bsvKpUSi6D7qej39cybLp0w71wcxRI4S6Fk2vVz2+fRjrBlqHG0oQDxT2BtATbK6JU5dwutLvhiYAXeUPHFCrS1dX1aCyf7/ymbtWf9oLm0KTqboOO5GX2YdeTwmz6VM3BxpXH8xJyOlp4L3vBY4fV3XdeGPxMMs1MZHp8k8zbzoB6IppCpe7xPYrm24C23UQCz20feamT18I9b7fz3zgZj0+V87cnDpXu0N0X31+eN2HhYXB+sw2XH2wr2VqSrVtu6xCLpXO+tLLQldM5wBdMcOlqGVou0sWF4Fbb82W9uukXWYqW7Pe2CYaZvnrr8+7MrRrQ6fP1Qm5dD2+NAC33ZaPcNHnA4PWuN0HM0eNvVBJStUHHTVjX8vBg+q+SKledRTL4iJw++3qM1fYonmPn39ele+81U6IAwp7CcrERbvynNviq10urgHDduf4+qXFVPv0+33gLW8BPvc5VUbb8a99LbBxY75+WyR1//S8wL33ZvnVAZWW1xZlM1b//POB3//9bG7BHFC0W8c1MOp7dfSoOue55wbj482wRX2/ZmezDJJSqkFgWNv5jQX0wa9dUsz6uo9xd8Wkrsi03Qa2u2RqKnOX6FWZIVdCyNVgf7ewkHeLaJeKnSMmlL/F5U6Zm8tWrU5ODvbBvCY7EgiQ8vLL01wkppvKdAFp15K+RruPpvsodbUsIeMCEl0xXKBUgpSFMq4843ae8yuuyOa2VlaAP/szZZWa7hpzgY7L1aCxXT3Ly6qt5eX8oiET34IhwJ2HXVu+2iI+flxNgr773dnTwrFj+fptfvazNAt6eTlz/eiUvvp+X311NulrRwNt3w6cckp67nlCughdMSVIWUaesvz9/PMzwZUSePr
p7DshBv3Js7NZxImUygeuXQ3T0+o7HWKoBxvTrWELba83ODDZriA7RPHWW/N9/od/UMfttwOf+lTWFqA22NC+cs1558XursIMwZQS+IM/AE47bfB+2/MOrkRoTW5DSMgoQGEvSSwu2vSp9/tq8lBb7ZrlZX96jJe9DPjmN9V3emCYnc2LszmxuHOnEvxeT4UKmpkSH3hARbqZ+WcA4CUvAd785rzF7gpv1IK5uKisdBf6KeHBB/NhnYuLKoUCoPp22mnu880BBQDuvDO7N/o8ewMO3wBrTyS3ln2RkLZI8dfUfYy7jz0V7fOdmhr0VWv/t51qQPumtQ8fyHzZ9grWXi/zpZurNk3ftM6maPu/7fQGdv52fb7ps5+YcJ+v/dkpmRt1f33zA2Y2yCrpDXzt1xEC2clVrGQsAFMKVKeOH7A90To3lxcanRNFiMGl9OaEo+6PORHqmkRcty4/UarFX58/Nyflpk2Dk6m9XpZuV/fFnITctMkv6rp8KK7dF98uZX7CU4h87LxOb1Dl36JOIWasPGmTVGGnK8ZDXaleXWGOpmvgtNOAhx7KYr1vvnmwjmeeySb+3vEO4NFHgUceUbJ69KjaAu8d71Dfn3++WnGq4+NXV9V7nSpXx6hv2ZJ3q/T76lw9mWu6fIQAfvpT9/X1++7MjYuLag7gxIn8/qa2W+RrX1PXrN1RepJUzy3oxZJV/i3qTCfQ6sYahCRCYfdQ1w84ZVcjU3DvuCOb6NQ+5kceAS68MBM8PUkKqHK33qpeJyeVX/vGG1WkirmYyF5IdMMN+Q2tb7hB+cjtyU7dh299a/DazPh2fT3T0yolsZ5ABfJzBOYgNz0NXHNNPl7+Xe9S12D6zUN+8tSFYnWlGogtFCNkJEgx6+s+xsEVM8xHbp9r4NAh5XbZtEm5aHzuD9NdYf89N6fqCqUr8PXDTt9r1q2PDRvUPTHrNeuZm3P7712x/JdfPuinL5IqIPT53FyxtMF1/PsRMmxAV0w1YiGNVSxAl2tgaUnVpePAn3gCuPjiwXP16s3VVSWH69YpS1Zb2rffDpx6KvD448D73+8OETT7bkaamNf83HPAxz6WXx0LKLfQBz6g6jezQOpXe2Xsr/2aiq4BlOWt21xaAv7mb/JlX/aywf7pe+X6t/BtoGE+Mdx2G3DllfW6TzqZKZJ0ixT1r/sYB4s9xDCs+X378taxEMqidVnMk5P5qBdz8jG20jO173Z/zElWnVDMZbEfOpRNwOpoHlebrhz1QqgyZjRO6N766rXvoz1hTUubjCugxT48yvrfQ1b+7KyyvrXFPjEBfOlLgzHu3/telpBLCDVxqsu7Yszvukv5wB94QL3X29+5kniZTE+74+vNhGJHj6r4+SeeyCY79+/Ap+qgAAAUz0lEQVRXFjKQLZ66/vqsTZ3szMxRD2Tybj4hxO6tz5K3F0lt3z7otzdZU6l+yZqAwl6C0ARaaHMLV2SHWf7gwcyV8cMfAnff7W5fZ0acmABuuWXQXWJz7Bjw0Y8Cf/u3+clRcwOLAwfUoqBt29Qeq/fcM1iPKxmYjs4B1GfXXKM+14K6tKRWppoTudPTg/nX3/OewYFpdVW5hOx7Glodu3nz4CIpfY6LYW50TUhrpJj1dR/j7oqR0j2BFnJzuBKHhSb/Jiczd4J2TZj7gW7aJOU557jdL/Zh5j+3j4UFddifuSZudcz83JyUr3zloMvDXFSl3UB2AjLfNncLC/4+mq4fHQtvL8CK/VsV+XchZFRBE64YIcRfALgMwDEA3wFwhZTyuRrGm5HHNYEWctG4Qv18bhFz6b4dAjg9naXCTcF0nbi49Vbgxz8e/OyUU/KfXX65mszduXMw70yvB7zxjaqehx7KPn/mmXwCMiGU+8X1hPPcc8ArXqHCKu2+6ntx7Jh6qtDtr64qS1/H6Pso8u/C8EXSBaq6Yu4DsFtKuSKE+AiA3QD+e/VujSc+kdDiZW4tp7d+kzJLxjU9reLPzURb2qW
hBxId0x1CCODMM4Hvfz+ze30cPjwopEeOZO6dXg9461uBX/kV5c93JRNbXVVuHrOddeuAd75TLUDS+XKuvDLvGtEbZ9jZJ887T0XdmHXpOP1t24CvfCXrw4kT8TmOkHinJHQjZNyoJOxSSjOt1FcB/Ha17ow3vp2IbB/u4mJenGdmlAju3JnfTEII4JJL8gm9ZmfzE6V6pagpjFICP/hB9l5vPuHCFOlf/EXgVa8CHn44//3nP59N1vos/+PHs74AwBveoHz1eiclLaZ6tal9rSY//ana7FuHVJq7MS0vqzDOj30s8+XHrOyYeDN8kXSNOidPrwTwP2usbyxx7USk3QBHjyr3ywtfmD/nggsG86brbeO+9CUV760HBSAvhr2eEnrb6jWFfHoa+NGP4n3/8Y/VQKRXuOp+2KtRfZiZKu+5J8tmaQ9wZuphF9/5DvDpTw9OZOrzJybcA1oIijdZS0SFXQhxP4AzHF9dJ6W8+2SZ6wCsAPhsoJ4dAHYAwPr160t1dhyx86Hff78SJp2rXKcBAPJpfnXI4M03533DTz2ViS6g6rziCvX3rbe6Qx5TRF1z/LjypwNqQNELoXzphTX9PvCa1yhrX0pVz/x85lZ54Quze6DdT0Ko8y65RLWlByMp83vALi3l5yPMQSvFFUPIWiMq7FLK14W+F0L8HoBLAVx0ctbWV88BAAcAYGZmJtHOGn/MfOj3358J01VXAevX510Dvg2lV1eVAM7Ouvc8PfVUtcJ08+Zs8jIkxL2eOswBwuSMk8O4KcIzM0qgTUvebEdK4O//Pp9E7MtfHswBr7/74AdVn6ens1w3dpnp6czSd7ltXBuFEEIqumKEEK8HsAvAFinlz+rpUvfYvFkJ+8MPZ7521ybLprtATy6arhlAZWA0kVLFqNtMTKjvXOL98per4+67B8VSCOAb31Dtm/HuF1ygBPTjH1eiOzGhomSeecY9ARui1wN+8hN17Nnjfsro9ZR7anER+Ld/c3//utep+0prnZA8ImBkx08W4kkAUwCWT370VSnlXOy8mZkZefjw4dLtNkkdqxJ1HXryLyUToW2l9vtq/1EA+NCH4kJ6+eVqAvKd71Q7MZlov7x2mcQQIjvMaBk9KevqS+/kbrrmdzrHjU7N69qLtd9Xr3oVqyu0UwgVjsnFRGStIYQ4IqWciZWrGhXzH6qcP+rUsSoxtuJUi70p+ouLg66HiYnM5TA15Y8o0dxzj0qq9aIX5ZOGAZk76NWvVqtCY7hCJs34eCGA008Hnn02+/6DH1QW+S23qKeGfl+Jve6HDvU0mZxUk6b6Phw8mH/iEELNTZhhk4SQQZhSIEAdOdl9GQgvvDDvr9YTlNpitUVPv9c++8VF5Zv2RawcP5530fzqr6qcLhqdTqCoG0Vz9tnKDbOyos5fXlaDzwUXqKeEHTtUzL05mGhRl1Jdp06LcPHFyq9vivXSkpoo1vdDT7La5QghDlKWp9Z9jEtKgTqyOLrqsPOQpxx66zo7f/rWrfml+P1+ln7ATgdgvn/lKweX+/f7qm8bN6b1aWpKpR7QbdlL8s387nrbPp0dctcuf05z855NTanr9m2r5zqXudJJVwGzO1YnZVVizAfvquOZZ4r3ZXVVJeoClCtGu3TOO09F22i/s17devBgPiLFnoB8+cuBb397MPb94ouVb97eBcnF8ePAL/1S5k93rercv18t+zd3STpxwh2nrjGfcgAVPaTj/ENPT0zoRchJUtS/7mNcLPYYtjWuN22OWYt20i3XIYSUZ57pz4muN4c2P9+1K+vXunXu8846S1nl2gK2y2zdmtVhJwI744zB+nRu+HXrpHzb29T5ZmIuM8mW/XSgrXvfTk52zveYxc6EXqTrINFip7BXwBQS1wYUIXbt8m+OobeT27VLZTC0v9dbyG3dmv9806asX7YrRrtD+v1sE4y3vW2wbj04SKnamJpS5detG+xLr5e1Y7e3sJBtUTc1Nfh9aBMO3bYre2Z
o4BzmdoaEjAKpwk5XTAXM5FJmlEjKhssf+YgKSXzXu1TcuOa884C3vEVFyezcmd/UWnPZZaqebdvy7pYjR7It9nSaASDb/NqcJD12DPirvxq8pk9/GnjpS7PIlAcfzFa83nzzYHk9CWr2D1ATu2YCsJmZbKJWCBXZEtuo2iaWFoAJvQg5SYr61310xWKXMrMifZN7MXeNKxe6lPnt7rS1bW43p7nwwvz5l1+e5UsvOkGr23JtTaevw3zKmJzMrmfXrsF+mNa8b3u6sm4XQtYioMXeDKYVaWYydG24fPQocO21Svr05J7e2k7z2GNZKgFtBa9bB3zqU5kVDWQbQ2/cmM+Bftdd+fpiOV7ssuaiI9OC1hOh7353Vp+U+VzoL31ptguT2ZfVVbVi9vzzs+/NTbBtK7uIFU8IGaTXdge6xObNwO7deRHS7hq92lIvDjJj2m0OHsxHkVxxhYoL371buTcuvBC47jr1euqpKkrGxyteob7v9dQAcd55/rJnngnccIMq3+8PRrk89ljenaMTfelt53bsUNvm7dihBiEdLSOEcs28971KxHfuzG9Vp+8boER9ejq7Z8wFQ0hxaLHXhC/scfNmJWhf+ALw67+uXu0NH/RmE65Mj+ZnS0sqdFAvbFpZAT7xCeWT/6wnr+aTTyqx1qtbr73Wfw3/8i/KAi/ip77vPrXxxY03KkHXmJtVr67mV7iaVri5AnfnzuyaddgmfeWEFIfCXgLX5sq++OkDB7IVoE8+qazsjRvzqyf1BGUs0+P8/GBul5UV4K//2t/X48eVpX3TTcqN4svoCKinhPn5zJVis317NgjpTTekVHXqLeqArN86q+V99+XdQTpTpXnfdH3aBWSmFQAo7oQUYayEvWgyrWH1wRbxUOqBO+/Mn//QQ8Ajj2RWOOCO9rAzPV50kXvBkMuHnupX7/ezaBktrPfdpyJter38QijdJz0IaetfDxQnTqg0B3fckb838/PKojfTJ7giYnRSMb3ByPQ0FxsRUpaxEXZT3FZX3cIz7PZ12J8t4qE9Ne2QRECV++hHgZ/9TH2vXRjmkwCQ/a0HDn3dgBLkfl9tE/fJT2ai3+upBFyf/OSgK8e0uCcnswnZ6Wk1AJmWtStsExgchPSqUu3nt+/N7t3KTWOW0/2x75vpfqkjTw8ha5WxEXZT3AC/8AwD00qfmMinltVPDS7XiX6/sKAmD/XkoxBZxIgW/XPPzbchpRI1nfxqYiJr0/Y//+Qnqg0pVd2nnZa3rM1JWr3bkiuRlj0AmRkl9X0wr9G1p6lpsevP7HLm/Qn50n2DJSEkQkpMZN1HmTh2M6GUjo1uKsbZXqpuJ+Ny9VWv2JyaGlxNaSfZ2rTJv/Rev5+cVO3Gkmb5YsR1bLrvvu3bN9juxEQ43tx37bG0CnXWRchaAl2LYzet4qZ97LbLwLZ2bUt2cTFzjRw9qt6bLow//dN8/Y8+qlLdmnujmmjrff16d+Irl+Vr7xNq8vzzg086s7P5TayB/H6iqa4Rc/NqHWtvlytaFyGkGGMj7EB7P/TQUnXXZGqIgwcHJzalVKJs740KZOkAbLeIr20zSscl6rq9554bvMYbbwSuuSaLoV+3LmszNI9gE8uyWKQuQkhxuEApEdfiI8BtfW7frgRLR3iYETBa1PQkqLkhs94b1VwgpCNXVlbU4qRY2+bnerJ106Zsz1TN448PXuOOHWpf1rk5dZiWtB7cPvzh+IS1r19Atpfrb/2W2tCb0S6E1M9YWexNUWSfU5f1qV0XvgVLpkvJTilgfv/UU1kO9hMnslhxXZ/P8nVFm9x1V35HJTtW3bzmm25yX2vKE5O985HZr6WlfJ53e9AjhNREiiO+7mOUk4D5klLFJkvLTPLFJhEPHcqnytV52E0WFgZzoPv6FCpbR9It185HZl32BK0QzJlOSBHQtcnTprDdCK5FN6HFRFXacsWM2zHgdvihXob/8MOZNe9aGavfm8v+U/sRqs8sZ9YDDE7
26icJbbGbPnxCSH1Q2C1sNwZQbqGMb7FR0UlEVwy4rt+Mejl6NEsHYOdcMd+7BqZQP+yJUF992gVjxtvb16NXri4uqvfclJqQ4UBht7AjYAD3opsQphjqZfIrK4NCGFugY/bJFYmjRV2nA7j/fiWcZtrdO++MD0w6Ja+dUhcYtOZd9QH5673qKr9oM4SRkOFDYXdgi0/RXXlMMdThhlIOCuHRo0oIb7jBHZ/us2ztqJeXvAT47nfzIZI6ImfbNuWmCQ1MPpcOMGjNu+qLuWAIIc1CYU+gqJVpb5mnpwtNIdTW9uqqSqZliqkdPXLbbXlL2xbbP/zDcMpblyvHJORjd8Xwu+pjXDohowOFfQhoMVxcVEm3jh9XVvT+/ZkQ6k03APVqiqkWWs3x46ouU0xTxNbsT2hgivn67UyTdjuhBVyEkOahsBcgFN9uf6dj2VdWMj/48rIqu3mzcr9ce636zo52saNH+n2VREz76c1t8kxxLTqpW0SYtWvo9tsH5wuKtE8IaYCUmMi6j1GOY/cRivVOScLli1P3xb8fOqQ2hN60aXDD6n6/XMx5qJ+xOH17I+t+nzHohDQNGMdeLyE/tO+7mCUcs3LvvTfvktHoSUrXFnMhV4hvqX9sQwt9ns5xoydmq/jSi6zuJYQUg8KeSMgPHfqujsVLeneh1VUVJy6E+ly3FUu6FepnbHGSfV6/r3ZAqhKDntpfQkg5KOyJhKzvYUwehnYXAvJtmVvMxdLguvoZi2ip+/pSBhNCSHmEtHPINsDMzIw8fPhw4+3WRVNuhNR2qlrATV/P9HR8NSwhZBAhxBEp5Uy0HIW9GKPqRigyCLTh23alJmhrQ3JCxpVUYacrpiCj6kYwJ3LN9yZtDkr2fVteVvntCSH1Q2EvyKju/mNvuH3FFf5UBCmDUt2W/ajeN0K6SC3CLoT4AIC/BPBiKeW/1lHnqOBaeDSKqyxN0T5xAlhYUMnLTKs8VVyHYdmP6n0jpItUFnYhxNkAtgJ4qnp3RgufwI3iKkst2s8/ny1lev75bCNtIF1ch+VuGsX7RkgXqWPP008A2AWg+VnYIeNb0DOKaNG++uosJ7qUKgXA0lK+nGvvVhM9SOh9V+k2IWS8qGSxCyHeBOBpKeUTwt4teQSo6iceN7+waREvLGSbYKdY3Pa9otuEkPElKuxCiPsBnOH46joAfwzlhokihNgBYAcArF+/vkAXy1GHn3hcBW779mKbg4RcToSQ8SMq7FLK17k+F0KcC+CXAWhr/SwAjwohNkkpf+io5wCAA4CKY6/S6RTq8hNXyZrYFkUHpFEN4SSElKO0K0ZK+TUA/06/F0L8M4CZUYmKadKNMoqLlopY3OPmciKEhOlsHHuTbpRxt3jH1eVECHFTm7BLKTfUVVddNOUn7oLFS586Id2hsxZ7k9DiJYSMEhT2mqDFSwgZFepYoEQIIWSEoLATQkjHoLATQkjHoLATQkjHoLATQkjHoLATQkjHaGXPUyHEswC+13jDg5wOYCRSILQM7wPvgYb3QTGq9+EcKeWLY4VaEfZRQQhxOGVj2K7D+8B7oOF9UIz7faArhhBCOgaFnRBCOsZaF/YDbXdgROB94D3Q8D4oxvo+rGkfOyGEdJG1brETQkjnoLCfRAjxASGEFEKc3nZf2kAI8RdCiG8JIf5RCPG/hBCntd2nphBCvF4I8U9CiCeFEH/Udn/aQAhxthDiQSHEN4QQXxdCvK/tPrWFEKIvhHhMCPHFtvtSFgo71H9qqE25n2q7Ly1yH4BXSSn/E4BvA9jdcn8aQQjRB3AjgIsBbATwViHExnZ71QorAD4gpdwI4D8DuHaN3gcAeB+Ab7bdiSpQ2BWfALALwJqdcJBSfllKuXLy7VehNidfC2wC8KSU8rtSymMAPgfgTS33qXGklD+QUj568u//ByVsZ7bbq+YRQpwF4A0Abmm7L1VY88IuhHgTgKellE+03ZcR4koA97TdiYY4E8D/Nd5/H2tQ0Ey
EEBsAnA/g/7Tbk1bYD2XkrbbdkSqsiR2UhBD3AzjD8dV1AP4Yyg3TeUL3QUp598ky10E9ln+2yb6R0UAI8QsA7gSwU0r5k7b70yRCiEsB/EhKeUQIMdt2f6qwJoRdSvk61+dCiHMB/DKAJ4QQgHI/PCqE2CSl/GGDXWwE333QCCF+D8ClAC6SaycO9mkAZxvvzzr52ZpDCLEOStQ/K6X8Qtv9aYHXAHijEOISAKcAOFUI8T+klL/bcr8Kwzh2AyHEPwOYkVKOYvKfoSKEeD2AjwPYIqV8tu3+NIUQYgJqsvgiKEF/BMB/k1J+vdWONYxQls0dAH4spdzZdn/a5qTF/kEp5aVt96UMa97HTn7ODQBeBOA+IcTjQojPtN2hJjg5YfweAPdCTRh+fq2J+kleA+DtAP7LyX//x09armQMocVOCCEdgxY7IYR0DAo7IYR0DAo7IYR0DAo7IYR0DAo7IYR0DAo7IYR0DAo7IYR0DAo7IYR0jP8PMgpwlvBA5S4AAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import numpy as np\n", "import tensorflow as tf\n", "\n", "import matplotlib.pyplot as plt\n", "%matplotlib inline\n", "import datetime as dt\n", "\n", "def generate_all_dataset():\n", " # --- Fake dataset ---\n", "\n", " np.random.seed(0)\n", "\n", " ntrain = 1000\n", " nvalid = 100\n", " ntest = 100\n", "\n", " mupos = np.array([2., 2.])\n", " sigmapos = np.array([[1., 0.], [0., 1.]])\n", " muneg = np.array([-2., -2.])\n", " sigmaneg = np.array([[.7, .2], [.2, .7]])\n", "\n", " def generate_a_dataset(mupos, sigmapos, muneg, sigmaneg, n):\n", " npos = int(n/2)\n", " nneg = n - npos\n", "\n", " Xpos = np.random.multivariate_normal(mupos, sigmapos, npos)\n", " Ypos = np.stack((np.ones((npos,)), np.zeros((npos,))), axis=1)\n", "\n", " Xneg = np.random.multivariate_normal(muneg, sigmaneg, nneg)\n", " Yneg = np.stack((np.zeros((nneg,)), np.ones((nneg,))), axis=1)\n", "\n", " X, Y = np.concatenate((Xpos, Xneg)), np.concatenate((Ypos, Yneg))\n", "\n", " idx = np.arange(n)\n", " np.random.shuffle(idx)\n", " X, Y = X[idx], Y[idx]\n", "\n", " return np.array(X, dtype='float32'), np.array(Y, dtype='float32')\n", "\n", " Xtrain, Ytrain = generate_a_dataset(\n", " mupos, sigmapos, muneg, sigmaneg, ntrain)\n", " Xvalid, Yvalid = generate_a_dataset(\n", " mupos, sigmapos, muneg, sigmaneg, nvalid)\n", " Xtest, Ytest = generate_a_dataset(mupos, sigmapos, muneg, sigmaneg, ntest)\n", "\n", " return (Xtrain, Ytrain), (Xvalid, Yvalid), (Xtest, Ytest)\n", "\n", "\n", "def plot_dataset(X, Y):\n", " plt.figure()\n", " idpos, = np.nonzero(Y[:, 0] == 1.)\n", " idneg, = np.nonzero(Y[:, 1] == 1.)\n", " plt.plot(X[idpos, 0], X[idpos, 1], 'r+')\n", " plt.plot(X[idneg, 0], X[idneg, 1], 'b.')\n", " plt.show()\n", "\n", "\n", "def demo(trainset, validset, testset):\n", " plot_dataset(*trainset)\n", "\n", "\n", "demo(*generate_all_dataset())" ] }, { "cell_type": "markdown", "metadata": {}, "source": 
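[ "The script above encodes each label as a two-column one-hot vector: $y = (1, 0)$ for the positive class (red `+`) and $y = (0, 1)$ for the negative class (blue `.`). The short derivation below assumes this encoding. It shows why a two-column softmax model is a logistic regression, and why its softmax cross-entropy is the binary cross-entropy used in the exercise below. Write $z = xA + b = (z_1, z_2)$ for the logits and $\\hat{y} = \\mathrm{softmax}(z)$:\n", "\n", "$$\\hat{y}_1 = \\frac{e^{z_1}}{e^{z_1} + e^{z_2}} = \\frac{1}{1 + e^{-(z_1 - z_2)}}, \\qquad \\hat{y}_2 = 1 - \\hat{y}_1$$\n", "\n", "so $\\hat{y}_1$ is the logistic (sigmoid) function of $z_1 - z_2$. Since $y_2 = 1 - y_1$ as well, the softmax cross-entropy becomes\n", "\n", "$$-\\sum_{k=1}^{2} y_k\\log\\hat{y}_k = -y_1\\log\\hat{y}_1 - (1-y_1)\\log(1-\\hat{y}_1)$$\n", "\n", "which is $\\mathcal{L}(y,\\hat{y})$ when $y$ and $\\hat{y}$ are read as the first column of the label and of the prediction." ] }, { "cell_type": "markdown", "metadata": {}, "source": 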
[ "## Exercise\n", "\n", "Inspired by the previous object-oriented version of the linear regression, write a logistic regression model:\n", "\n", "$$\\hat{y} = \\mathrm{softmax}(xA+b)$$\n", "\n", "It can be trained by minimizing the cross-entropy loss:\n", "\n", "$$ \\mathcal{L}(y,\\hat{y}) = -y\\log\\hat{y} - (1-y)\\log(1-\\hat{y})$$\n", "\n", "In fact, this model is a **one-layer perceptron**!\n", "\n", "You will have to use the following operations:\n", "- `tf.matmul`\n", "- `tf.nn.softmax`\n", "- `tf.losses.softmax_cross_entropy`\n", "\n", "Note that `softmax_cross_entropy` takes the logits as input, i.e. the output of the linear model before the softmax (its arguments are the one-hot labels first, then the logits). Both `softmax` and `softmax_cross_entropy` expect a two-column input (one column per class).\n", "\n", "Guidelines:\n", "\n", "1) Copy the code of the `LinearRegressionV3` class (or of the `LinearRegression` class below) and rename the class to `LogisticRegressionV1`\n", "\n", "2) Modify the `_build_network` method to implement a logistic regression model with a cross-entropy loss\n", "\n", "3) Add a `_compute_pred_loss(self, set_iterator_init, nbatches)` method that returns the predictions and the loss\n", "\n", "4) Modify the `train` method to use `_compute_pred_loss` on the test set, and make `train` return the predictions on the test set\n", "\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# object-oriented version of the linear regression\n", "class LinearRegression:\n", "\n", " def __init__(self, dtype, ninputs, noutputs,\n", " learning_rate=5e-3, training_epochs=100, batchsize=50):\n", "\n", " self.nc = ninputs\n", " self.no = noutputs\n", "\n", " self.dt_shapes = (tf.TensorShape((None, ninputs)),\n", " tf.TensorShape((None, noutputs)))\n", " self.dt_types = (dtype, dtype)\n", "\n", " self.learning_rate = learning_rate\n", " self.training_epochs = training_epochs\n", " 
self.batchsize = batchsize\n", "\n", " self._build_network()\n", " \n", " # Global variable initializer\n", " self.init_global = 
tf.initializers.global_variables()\n", "\n", " def _build_network(self):\n", " # Create only ONE iterator, based on the types and shapes of one of the datasets:\n", " # all the datasets must have the same types and shapes\n", " self.dataset_iterator = tf.data.Iterator.from_structure(\n", " self.dt_types, self.dt_shapes)\n", "\n", " # Input tensors produced by the dataset iterator\n", " x, y = self.dataset_iterator.get_next()\n", "\n", " # Linear regression parameters\n", " self.A = tf.Variable(tf.zeros([self.nc, self.no]))\n", " self.b = tf.Variable(tf.zeros([self.no]))\n", " # All parameters are gathered into var_list\n", " var_list = [self.A, self.b]\n", "\n", " # Actual linear regression\n", " self.pred = tf.matmul(x, self.A) + self.b\n", "\n", " # Model loss\n", " self.loss = tf.losses.mean_squared_error(y, self.pred)\n", "\n", " self.optimizer = tf.train.GradientDescentOptimizer(\n", " self.learning_rate).minimize(self.loss, var_list=var_list)\n", "\n", " def _prepareset(self, dataset, shuffle=True):\n", " if shuffle:\n", " dataset = dataset.shuffle(buffer_size=1000)\n", " dataset = dataset.batch(self.batchsize)\n", " return dataset\n", "\n", " def _compute_loss(self, set_iterator_init, nbatches):\n", " self.session.run(set_iterator_init)\n", " lossval = 0.\n", " for b in range(nbatches):\n", " batch_lossval, = self.session.run([self.loss])\n", " lossval += batch_lossval\n", " lossval /= nbatches\n", " return lossval\n", "\n", " def _compute_gradient_step(self, set_iterator_init, nbatches):\n", " self.session.run(set_iterator_init)\n", " lossval = 0.\n", " for b in range(nbatches):\n", " _, batch_lossval = self.session.run([self.optimizer, self.loss])\n", " lossval += batch_lossval\n", " lossval /= nbatches\n", " return lossval\n", "\n", " def train(self, trainset, validset, testset):\n", "\n", " (Xtrain, Ytrain) = trainset\n", " (Xvalid, Yvalid) = validset\n", " (Xtest, Ytest) = testset\n", "\n", " # --- Linear Regression ---\n", 
"\n", " ntrain, _ = Xtrain.shape\n", " 
ntest, _ = Xtest.shape\n", " nvalid, _ = Xvalid.shape\n", "\n", " trainset = self._prepareset(\n", " tf.data.Dataset.from_tensor_slices((Xtrain, Ytrain)))\n", " testset = self._prepareset(\n", " tf.data.Dataset.from_tensor_slices((Xtest, Ytest)), shuffle=False)\n", " validset = self._prepareset(\n", " tf.data.Dataset.from_tensor_slices((Xvalid, Yvalid)), shuffle=False)\n", "\n", " ntrainbatches = int(np.ceil(ntrain/self.batchsize))\n", " ntestbatches = int(np.ceil(ntest/self.batchsize))\n", " nvalidbatches = int(np.ceil(nvalid/self.batchsize))\n", "\n", " # Create one initializer per dataset\n", " training_init_op = self.dataset_iterator.make_initializer(trainset)\n", " test_init_op = self.dataset_iterator.make_initializer(testset)\n", " validation_init_op = self.dataset_iterator.make_initializer(validset)\n", "\n", " with tf.Session() as self.session:\n", "\n", " # We call the initialization of A and b\n", " self.session.run(self.init_global)\n", "\n", " # We compute the train and validation loss\n", " # Note that you just have to run another iterator initializer to change the set\n", "\n", " trainloss = self._compute_loss(training_init_op, ntrainbatches)\n", " validationloss = self._compute_loss(\n", " validation_init_op, nvalidbatches)\n", " print(\"Init\\t\\t train loss %f\\t valid loss %f\" %\n", " (trainloss, validationloss))\n", "\n", " # We loop over epochs\n", " for epoch in range(self.training_epochs):\n", "\n", " trainloss = self._compute_gradient_step(\n", " training_init_op, ntrainbatches)\n", " validationloss = self._compute_loss(\n", " validation_init_op, nvalidbatches)\n", " print(\"Epoch %03d\\t train loss %f\\t valid loss %f\" %\n", " (epoch+1, trainloss, validationloss))\n", "\n", " # We compute the test loss\n", " testloss = self._compute_loss(test_init_op, ntestbatches)\n", " print(\"Test loss %f\" % (testloss,))\n", "\n", " # Learned parameters\n", " Aval, bval = self.session.run([self.A, self.b])\n", " print(\"Estimated 
A\\n\", Aval)\n", " 
print('Estimated b\\n', bval)\n", " # Here the session is closed automatically" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Metrics\n", "\n", "In the previous script, each time you compute the loss of a batch, the result is copied from the memory of the computation engine (the TensorFlow runtime, possibly on a GPU) back to the host memory. Moreover, the averaging is actually done by Python on the host.\n", "\n", "You can speed up the loss computation loop by using `metrics`, a trick that keeps all the work inside the memory of the computation engine.\n", "\n", "A metric contains an internal state that is updated for each batch. Once all the examples have been seen, the final internal state is used to compute the actual metric.\n", "\n", "For example, a mean metric consists of two internal states, `total` and `count`.\n", "At each update, the targeted tensor (for example the loss) is added to `total`, and `count` is incremented by its number of elements, i.e. by one for a scalar such as the loss.\n", "At the end, `total` is divided by `count` to give the actual mean metric.\n", "\n", "### How to use a metric?\n", "\n", "1) Create a metric on the targeted tensor(s) (like the loss). It returns two tensors, `actualmetric` and `metricupdate`.\n", "\n", "2) Initialize the internal state of the metric (metric states are local variables, initialized by running `tf.initializers.local_variables()`).\n", "\n", "3) Update the internal state of the metric by running `metricupdate` for each batch inside a loop.\n", "\n", "4) Get the metric value by running `actualmetric`.\n", "\n", "### Example\n", "\n", "Here is how you can use `tf.metrics.mean` to average the loss over all the batches:\n", "\n", "```python\n", "class LogisticRegressionV2(LogisticRegressionV1):\n", "\n", " def __init__(self, *args, **kwargs):\n", " super().__init__(*args, **kwargs)\n", " \n", " # Initializer of the metrics internal states (local variables),\n", " # must be created after the network is built\n", " 
self.init_local = tf.initializers.local_variables()\n", " \n", " def _build_network(self):\n", " super()._build_network()\n", "\n", " # Metrics\n", " self.lossmetric, 
self.lossmetric_update = tf.metrics.mean(self.loss)\n", "\n", " def _compute_loss(self, set_iterator_init, nbatches):\n", " # Initialize all the metrics\n", " self.session.run(self.init_local)\n", " # Initialize the iterator\n", " self.session.run(set_iterator_init)\n", " # Loop over all the batches, updating the metric\n", " for b in range(nbatches):\n", " self.session.run(self.lossmetric_update)\n", " lossval = self.session.run(self.lossmetric)\n", " return lossval\n", "```\n", "\n", "### Exercise\n", "\n", "Complete the class `LogisticRegressionV2` with `_compute_gradient_step` and `_compute_pred_loss` methods that compute the loss through a metric.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Tracking the accuracy\n", "\n", "Metrics are also useful to track indicators other than the loss.\n", "For example, in our classification task we can track the accuracy of the model, i.e. the fraction of predictions that match the labels.\n", "\n", "Here is code that computes the loss and the accuracy at the same time:\n", "\n", "```python\n", "class LogisticRegressionV3(LogisticRegressionV2):\n", "\n", " def _build_network(self):\n", " super()._build_network()\n", "\n", " # Compare predicted and true class indices (argmax over the two columns);\n", " # assumes that _build_network stores the labels as self.y\n", " self.accuracymetric, self.accuracymetric_update = tf.metrics.accuracy(\n", " tf.argmax(self.y, axis=1), tf.argmax(self.pred, axis=1))\n", "\n", " def _compute_loss(self, set_iterator_init, nbatches):\n", " # Initialize all the metrics\n", " self.session.run(self.init_local)\n", " # Initialize the iterator\n", " self.session.run(set_iterator_init)\n", " # Loop over all the batches, updating both metrics\n", " for b in range(nbatches):\n", " self.session.run(\n", " [self.lossmetric_update, self.accuracymetric_update])\n", " return self.session.run([self.lossmetric, self.accuracymetric])\n", "```\n", "\n", "Please refer to this page to see other possible 
metrics:\n", "https://www.tensorflow.org/api_docs/python/tf/metrics\n", "\n", "### Exercise\n", "\n", "Complete the class `LogisticRegressionV3` so that it has a `_compute_pred_loss` method that computes the predictions, the loss and the accuracy, and a `train` method that 
displays the accuracy.\n", "As this dataset is easy, use a very low learning rate (1e-10) if you want to see the accuracy evolve." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Tensorboard\n", "\n", "Tensorboard is an easy way to monitor the training of your model.\n", "It is a web-browser dashboard that displays the metrics, and possibly the parameters, of your model.\n", "\n", "To use it, you first need to export your metrics to log files.\n", "Tensorboard then reads the log files and displays the corresponding plots in your browser.\n", "\n", "### Summaries\n", "\n", "Summaries are objects that can be evaluated and written to Tensorboard log files.\n", "You need to wrap your metrics in summaries.\n", "\n", "After defining a metric, just do:\n", "\n", "```python\n", "trainlosssummary = tf.summary.scalar(\"train_loss\", lossmetric)\n", "```\n", "\n", "You can merge multiple summaries into one:\n", "\n", "```python\n", "trainlosssummary = tf.summary.scalar(\"train_loss\", lossmetric)\n", "trainaccuracysummary = tf.summary.scalar(\"train_accuracy\", accuracymetric)\n", "trainsummaries = tf.summary.merge([trainlosssummary, trainaccuracysummary])\n", "```\n", "\n", "Now `trainsummaries` contains all the summaries of the training set.\n", "\n", "If you have multiple sets (e.g. 
train and valid), you should create one summary per set, even if they are based on the same metric:\n", "\n", "```python\n", "trainlosssummary = tf.summary.scalar(\"train_loss\", lossmetric)\n", "trainaccuracysummary = tf.summary.scalar(\"train_accuracy\", accuracymetric)\n", "trainsummaries = tf.summary.merge([trainlosssummary, trainaccuracysummary])\n", "\n", "validationlosssummary = tf.summary.scalar(\"validation_loss\", lossmetric)\n", "validationaccuracysummary = tf.summary.scalar(\"validation_accuracy\", accuracymetric)\n", "validationsummaries = tf.summary.merge([validationlosssummary, validationaccuracysummary])\n", "```\n", "\n", "### Writer\n", "\n", "A writer is an object that records summaries to event files in a log folder.\n", "\n", "To instantiate a writer, do the following inside a session:\n", "```python\n", "with tf.Session() as session:\n", " #[...]\n", " logdir = \"logs/somename/\" + dt.datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n", " writer = tf.summary.FileWriter(logdir)\n", " #[...]\n", "```\n", "\n", "Now, at each epoch, you can evaluate the summaries and add them to the writer:\n", "\n", "```python\n", "with tf.Session() as session:\n", "\n", " # We call the global initialization\n", " session.run(init_global)\n", " # Create the log writer\n", " logdir = \"logs/logisticregression/\" + dt.datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n", " writer = tf.summary.FileWriter(logdir)\n", " #[...]\n", " for epoch in range(training_epochs):\n", " # Work on the training set\n", " #[...]\n", " # And then\n", " trainsummariesval = session.run(trainsummaries) # evaluate the summary\n", " writer.add_summary(trainsummariesval, epoch+1) # record the summary\n", " # Work on the validation set\n", " #[...]\n", " # And then\n", " validationsummariesval = session.run(validationsummaries) # evaluate the summary\n", " writer.add_summary(validationsummariesval, epoch+1) # record the summary\n", " # Flush the pending summaries to disk once training is done\n", " writer.close()\n", "```\n", "\n", "### 
Analyzing the logs\n", "\n", "A `logs/logisticregression/` folder will be created in the directory where the script is run.\n", "It contains log files describing the evolution of the training.\n", "\n", "To analyze them, launch Tensorboard from a bash terminal:\n", "\n", "```bash\n", "tensorboard --logdir=logs/logisticregression\n", "```\n", "It starts a local web server (by default at http://localhost:6006) that you can open in a web browser to monitor the metrics of your model.\n", "\n", "**Caution!** As summaries and log files accumulate, it is highly recommended to restart the Python kernel and to erase the log folder before each execution of the script.\n", "\n", "### Your turn!\n", "\n", "1) Create a `LogisticRegressionV4` class derived from `LogisticRegressionV3`\n", "\n", "2) Modify the `_build_network` and `train` methods to support file logging.\n", "\n", "3) Analyze the logs in Tensorboard\n", "\n", "4) Create a `LogisticRegression` class (no subclassing) merging all the logistic regression versions.\n" ] } ], "metadata": { "celltoolbar": "Tags", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" } }, "nbformat": 4, "nbformat_minor": 2 }