|
34 | 34 | }, |
35 | 35 | { |
36 | 36 | "cell_type": "code", |
37 | | - "execution_count": 4, |
| 37 | + "execution_count": 30, |
38 | 38 | "metadata": { |
39 | 39 | "autoscroll": false, |
40 | 40 | "collapsed": true, |
|
50 | 50 | }, |
51 | 51 | { |
52 | 52 | "cell_type": "code", |
53 | | - "execution_count": 5, |
| 53 | + "execution_count": 31, |
54 | 54 | "metadata": { |
55 | 55 | "autoscroll": false, |
56 | 56 | "collapsed": true, |
|
335 | 335 | ] |
336 | 336 | }, |
337 | 337 | { |
338 | | - "attachments": {}, |
339 | 338 | "cell_type": "markdown", |
340 | 339 | "metadata": {}, |
341 | 340 | "source": [ |
|
353 | 352 | ] |
354 | 353 | }, |
355 | 354 | { |
356 | | - "attachments": {}, |
357 | 355 | "cell_type": "markdown", |
358 | 356 | "metadata": {}, |
359 | 357 | "source": [ |
|
396 | 394 | { |
397 | 395 | "cell_type": "code", |
398 | 396 | "execution_count": 31, |
399 | | - "metadata": {}, |
| 397 | + "metadata": { |
| 398 | + "collapsed": true |
| 399 | + }, |
400 | 400 | "outputs": [], |
401 | 401 | "source": [ |
402 | 402 | "# 产生一组数据,模型不知道数据怎么产生的。\n", |
|
432 | 432 | }, |
433 | 433 | { |
434 | 434 | "cell_type": "code", |
435 | | - "execution_count": 59, |
| 435 | + "execution_count": 8, |
436 | 436 | "metadata": { |
437 | 437 | "collapsed": true |
438 | 438 | }, |
|
491 | 491 | { |
492 | 492 | "cell_type": "code", |
493 | 493 | "execution_count": 56, |
494 | | - "metadata": {}, |
| 494 | + "metadata": { |
| 495 | + "collapsed": true |
| 496 | + }, |
495 | 497 | "outputs": [], |
496 | 498 | "source": [ |
497 | 499 | "dm = Discriminator(2)\n", |
|
584 | 586 | { |
585 | 587 | "cell_type": "code", |
586 | 588 | "execution_count": 101, |
587 | | - "metadata": {}, |
| 589 | + "metadata": { |
| 590 | + "collapsed": true |
| 591 | + }, |
588 | 592 | "outputs": [], |
589 | 593 | "source": [ |
590 | 594 | "criterion = nn.BCELoss() # binary cross entropy loss\n", |
|
602 | 606 | { |
603 | 607 | "cell_type": "code", |
604 | 608 | "execution_count": 146, |
605 | | - "metadata": {}, |
| 609 | + "metadata": { |
| 610 | + "collapsed": true |
| 611 | + }, |
606 | 612 | "outputs": [], |
607 | 613 | "source": [ |
608 | 614 | "num_epochs = 2000\n", |
|
688 | 694 | { |
689 | 695 | "cell_type": "code", |
690 | 696 | "execution_count": 157, |
691 | | - "metadata": {}, |
| 697 | + "metadata": { |
| 698 | + "collapsed": true |
| 699 | + }, |
692 | 700 | "outputs": [], |
693 | 701 | "source": [ |
694 | 702 | "gm_init = Generator(2, 2)\n", |
|
698 | 706 | { |
699 | 707 | "cell_type": "code", |
700 | 708 | "execution_count": 158, |
701 | | - "metadata": {}, |
| 709 | + "metadata": { |
| 710 | + "collapsed": true |
| 711 | + }, |
702 | 712 | "outputs": [], |
703 | 713 | "source": [ |
704 | 714 | "g_input_data = g_input.data.numpy()\n", |
|
732 | 742 | " ax.legend()" |
733 | 743 | ] |
734 | 744 | }, |
| 745 | + { |
| 746 | + "cell_type": "markdown", |
| 747 | + "metadata": { |
| 748 | + "collapsed": true |
| 749 | + }, |
| 750 | + "source": [ |
| 751 | + "## Conditional GANs\n", |
| 752 | + "以 mnist dataset 为例,解释 Conditional GANs" |
| 753 | + ] |
| 754 | + }, |
| 755 | + { |
| 756 | + "cell_type": "markdown", |
| 757 | + "metadata": {}, |
| 758 | + "source": [ |
| 759 | + "### MNIST 数据集读取\n", |
| 760 | + "- 一个著名的手写数字数据集" |
| 761 | + ] |
| 762 | + }, |
735 | 763 | { |
736 | 764 | "cell_type": "code", |
737 | | - "execution_count": null, |
| 765 | + "execution_count": 77, |
738 | 766 | "metadata": { |
739 | 767 | "collapsed": true |
740 | 768 | }, |
741 | 769 | "outputs": [], |
| 770 | + "source": [ |
| 771 | + "import torchvision.datasets as dset\n", |
| 772 | + "import torchvision.transforms as transforms\n", |
| 773 | + "import os" |
| 774 | + ] |
| 775 | + }, |
| 776 | + { |
| 777 | + "cell_type": "code", |
| 778 | + "execution_count": 78, |
| 779 | + "metadata": {}, |
| 780 | + "outputs": [], |
| 781 | + "source": [ |
| 782 | + "if not os.path.exists('./data'):\n", |
| 783 | + " os.mkdir('./data')\n", |
| 784 | + " \n", |
| 785 | + "train_set = dset.MNIST(root='./data', train=True, \n", |
| 786 | + " transform=transforms.Compose([transforms.ToTensor(), \n", |
| 787 | + " transforms.Normalize((0.5,), (1.0,))]),\n", |
| 788 | + " download=True)" |
| 789 | + ] |
| 790 | + }, |
| 791 | + { |
| 792 | + "cell_type": "code", |
| 793 | + "execution_count": 70, |
| 794 | + "metadata": {}, |
| 795 | + "outputs": [], |
| 796 | + "source": [ |
| 797 | + "dataloader = torch.utils.data.DataLoader(train_set, batch_size=200, shuffle=True)" |
| 798 | + ] |
| 799 | + }, |
| 800 | + { |
| 801 | + "cell_type": "markdown", |
| 802 | + "metadata": {}, |
| 803 | + "source": [ |
| 804 | + "Take a look at the data" |
| 805 | + ] |
| 806 | + }, |
| 807 | + { |
| 808 | + "cell_type": "code", |
| 809 | + "execution_count": 71, |
| 810 | + "metadata": {}, |
| 811 | + "outputs": [], |
| 812 | + "source": [ |
| 813 | + "data_sample = next(iter(dataloader))" |
| 814 | + ] |
| 815 | + }, |
| 816 | + { |
| 817 | + "cell_type": "markdown", |
| 818 | + "metadata": {}, |
| 819 | + "source": [ |
| 820 | + "- dataloader 迭代返回的每个元素是一个 list, 其第一个元素是这批 batch 的图像,第二个元素是这批 batch 的 label\n", |
| 821 | + "- 第一个元素是一个 FloatTensor, 大小为 (200 x 1 x 28 x 28),即 (batch_size x channel x height x width)。因为是灰度图像,所以只有一个 channel\n", |
| 822 | + "- 因为 channel 这个维度没用,所以我们用 view 把它消解" |
| 823 | + ] |
| 824 | + }, |
| 825 | + { |
| 826 | + "cell_type": "code", |
| 827 | + "execution_count": 72, |
| 828 | + "metadata": {}, |
| 829 | + "outputs": [ |
| 830 | + { |
| 831 | + "name": "stdout", |
| 832 | + "output_type": "stream", |
| 833 | + "text": [ |
| 834 | + "torch.Size([200, 1, 28, 28])\n", |
| 835 | + "torch.Size([200, 28, 28])\n" |
| 836 | + ] |
| 837 | + } |
| 838 | + ], |
| 839 | + "source": [ |
| 840 | + "print(data_sample[0].size())\n", |
| 841 | + "print(data_sample[0].view(-1, 28, 28).size())" |
| 842 | + ] |
| 843 | + }, |
| 844 | + { |
| 845 | + "cell_type": "code", |
| 846 | + "execution_count": 74, |
| 847 | + "metadata": {}, |
| 848 | + "outputs": [ |
| 849 | + { |
| 850 | + "data": { |
| 851 | + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAADTpJREFUeJzt3W+oVHUex/HPV82blJGurEilpolRQqaX8sFladlV3AhM\ngsieuGzs7UGRxUaG+2CLCCSsZetBoCXa4pbbP7Ja1lIWbWEVzdQs13Triop5EwW1B2X63Qdz7u7V\n7vxmnDkzZ8bv+wWXO3O+c858Gf3cc878Zs7P3F0A4hlUdAMAikH4gaAIPxAU4QeCIvxAUIQfCIrw\nA0ERfiAowg8ENaSZT2ZmfJwQaDB3t2oeV9ee38xmm9keM9tnZo/Xsy0AzWW1frbfzAZL+kLSTEkH\nJW2RNM/dP0+sw54faLBm7PlvkbTP3b909+8lvSZpTh3bA9BE9YT/KkkH+t0/mC07h5l1m9lWM9ta\nx3MByFnD3/Bz96WSlkoc9gOtpJ49/yFJ1/S7f3W2DEAbqCf8WyRNMrNrzWyopHskrcmnLQCNVvNh\nv7v/YGYPSlorabCk5e7+WW6dAWiomof6anoyzvmBhmvKh3wAtC/CDwRF+IGgCD8QFOEHgiL8QFCE\nHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCauqlu1Gbjo6OZP2TTz4pWxsyJP1PPG3atGT91KlTyTra\nF3t+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiKcf42MH369GT9+uuvr3nbw4YNS9YZ5794secHgiL8\nQFCEHwiK8ANBEX4gKMIPBEX4gaDqGuc3sx5JJyWdkfSDu3fm0RTONWgQf6ORvzw+5PNzdz+aw3YA\nNBG7FCCoesPvkj4ws4/NrDuPhgA0R72H/V3ufsjMfirpQzP7t7tv7P+A7I8CfxiAFlPXnt/dD2W/\neyW9LemWAR6z1N07eTMQaC01h9/MLjOz4X23Jc2StCuvxgA0Vj2H/aMlvW1mfdv5i7v/PZeuADRc\nzeF39y8l3ZRjLyijq6urYdu+6ab0P+G6desa9twoFkN9QFCEHwiK8ANBEX4gKMIPBEX4gaC4dHcb\n2LFjR1tuG62NPT8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBMU4fxuo9LXblDNnziTr7l7zttHe2PMD\nQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCM81/k3njjjWT96FEmWI6KPT8QFOEHgiL8QFCEHwiK8ANB\nEX4gKMIPBFVxnN/Mlku6Q1Kvu0/Jlo2UtFrSeEk9ku529+ONaxPt6tZbby1bW7JkSXLdxYsXJ+vv\nv/9+TT2hpJo9/wpJs89b9rik9e4+SdL67D6ANlIx/O6+UdKx8xbPkbQyu71S0p059wWgwWo95x/t\n7oez219LGp1TPwCapO7P9ru7m1nZC8GZWbek7nqfB0C+at3zHzGzMZKU/e4t90B3X+rune7eWeNz\nAWiAWsO/RtL87PZ8Se/k0w6AZqkYfjN7VdK/JE02s4Nmdp+kxZJmmtleSb/M7gNoI9bM67an3htA\neXv37k3WJ06cWLY2a9as5Lrr1q2rqac+Q4cOTdbffffdsrWZM2cm1/3uu++S9dmzzx+BPteGDRuS\n9YuVu1s1j+MTfkBQhB8IivADQRF+ICjCDwRF+IGguHR3G7j00kuT9f3795et7dy5M+92zvHQQw8l\n65WG81I6OjqS9RkzZiTrUYf6qsWeHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeC4iu9LWDs2LHJ+vbt\n25P1r776qmxt+vTpNfXUp6urK1lfsWJFsj5hwoSytT179iTXnTx5crK+bdu2ZL2zM+bFo/hKL4Ak\nwg8ERfiBoAg/EBThB4Ii/EBQhB8Iiu/zt4BKY/FXXnllw5570KD03/8FCxYk66lxfEk6cOBA2dq0\nadOS627atClZv+KKK5J1pLHnB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgKo7zm9lySXdI6nX3Kdmy\nJyT9VtI32cMWufvfGtUkGueRRx5J1u+6665k/ZtvvknW58yZU7ZWaXrvYcOGJetbtmxJ1pFWzZ5/\nhaSBJkL/o7tPzX4IPtBmKobf3TdKOtaEXgA0UT3n/A+a2U4zW25mI3LrCEBT1Br+FyVNlDRV0mFJ\nz5Z7oJl1m9lWM9ta43MBaICawu/uR9z9jLuflbRM0i2Jxy519053j3k1RaBF1RR+MxvT7+5cSbvy\naQdAs1Qz1PeqpNskjTKzg5L+IOk2M5sqySX1SLq/gT0CaICK4Xf3eQMsfrkBvaAACxcurGv9tWvX\nJuupOQeuu+665LojR45M1k+fPp2sI41P+AFBEX4gKMIPBEX4gaAIPxAU4QeC4tLdF4HUV2NfeOGF\n5LqjRo1K1isNpz3//PPJ+uDBg8vW7r333uS6I0akvzIyY8aMZB1p7PmBoAg/EBThB4Ii/EBQhB8I\nivADQRF+ICjG+VvA5s2bk/Xjx48n6zfeeGNNtWqcOnUqWZ80aVKyvmjRorK1uXPnJtd192T9qaee\nStaRxp4fCIrwA0ERfiAowg8ERfiBoAg/EBThB4KySmOpuT6ZWfOe7CLyzDPPJOuPPvpokzpprkrj\n+E8++WSyfvbs2TzbaRvubtU8jj0/EBThB4Ii/EBQhB8IivADQRF+ICjCDwRVcZzfzK6R9Iqk0ZJc\n0lJ3/5OZjZS0WtJ4ST2S7nb35BfPGeevTera95I0ZcqUsrVNmzYl1+3o6Kippz7ffvttst7T01O2\n9vTTTyfXXb16dbLezM+otJM8x/l/kPQ7d79B0gxJD5jZDZIel7Te3SdJWp/dB9AmKobf3Q+7+7bs\n9klJuyVdJWmOpJXZw1ZKurNRTQLI3wWd85vZeEk3S9osabS7H85KX6t0WgCgTVR9DT8zu1zSm5Ie\ndvcTZv8/rXB3L3c+b2bdkrrrbRRAvqra85vZJSoFf5W7v5UtPmJmY7L6GEm9A63r7kvdvdPdO/No\nGEA+KobfSrv4lyXtdvfn+pXWSJqf3Z4v6Z382wPQKNUM9XVJ+kjSp5L6viO5SKXz/r9KGitpv0pD\nfccqbIuxmSZbtWpVsj5v3ry6tv/YY48l60uWLKlr+7hw1Q71VTznd/d/Siq3sV9cSFMAWgef8AOC\nIvxAUIQfCIrwA0ERfiAowg8ExaW7L3LDhw9P1seNG5esv/7668n6iRMnkvVly5aVrb300kvJdVEb\nLt0NIInwA0ERfiAowg8ERfiBoAg/EBThB4JinB+4yDDODyCJ8ANBEX4gKMIPBEX4gaAIPxAU4QeC\nIvxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4KqGH4zu8bM/mFmn5vZZ2a2IFv+hJkdMrPt2c/t\njW8XQF4qXszDzMZIGuPu28xsuKSPJd0p6W5Jp9x9SdVPxsU8gIar9mIeQ6rY0GFJh7PbJ81st6Sr\n6msPQNEu6JzfzMZLulnS5mzRg2a208yWm9mIMut0m9lWM9taV6cAclX1NfzM7HJJGyQ97e5vmdlo\nSUcluaSnVDo1+E2FbXDYDzRYtYf9VYXfzC6R9J6kte7+3AD18ZLec/cpFbZD+IEGy+0CnmZmkl6W\ntLt/8LM3AvvMlbTrQpsEUJxq3u3vkvSRpE8lnc0WL5I0T9JUlQ77eyTdn705mNoWe36gwXI97M8L\n4Qcaj+v2A0gi/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBFXx\nAp45Oyppf7/7o7JlrahVe2vVviR6q1WevY2r9oFN/T7/j57cbKu7dxbWQEKr9taqfUn0VquieuOw\nHwiK8ANBFR3+pQU/f0qr9taqfUn0VqtCeiv0nB9AcYre8wMoSCHhN7PZZrbHzPaZ2eNF9FCOmfWY\n2afZzMOFTjGWTYPWa2a7+i0baWYfmtne7PeA06QV1FtLzNycmFm60Neu1Wa8bvphv5kNlvSFpJmS\nDkraImmeu3/e1EbKMLMeSZ3uXviYsJn9TNIpSa/0zYZkZs9IOubui7M/nCPcfWGL9PaELnDm5gb1\nVm5m6V+rwNcuzxmv81DEnv8WSfvc/Ut3/17Sa5LmFNBHy3P3jZKOnbd4jqSV2e2VKv3naboyvbUE\ndz/s7tuy2ycl9c0sXehrl+irEEWE/ypJB/rdP6jWmvLbJX1gZh+bWXfRzQxgdL+Zkb6WNLrIZgZQ\ncebmZjpvZumWee1qmfE6b7zh92Nd7j5N0q8kPZAd3rYkL52ztdJwzYuSJqo0jdthSc8W2Uw2s/Sb\nkh529xP9a0W+dgP0VcjrVkT4D0m6pt/9q7NlLcHdD2W/eyW9rdJpSis50jdJava7t+B+/sfdj7j7\nGXc/K2mZCnztspml35S0yt3fyhYX/toN1FdRr1sR4d8iaZKZXWtmQyXdI2lNAX38iJldlr0RIzO7\nTNIstd7sw2skzc9uz5f0ToG9nKNVZm4uN7O0Cn7tWm7Ga3dv+o+k21V6x/8/kn5fRA9l+pogaUf2\n81nRvUl6VaXDwNMqvTdyn6SfSFovaa+kdZJGtlBvf1ZpNuedKgVtTEG9dal0SL9T0vbs5/aiX7tE\nX4W8bnzCDwiKN/yAoAg/EBThB4Ii/EBQhB8IivADQRF+ICjCDwT1X4+CPOHbV0x+AAAAAElFTkSu\nQmCC\n", |
| 852 | + "text/plain": [ |
| 853 | + "<matplotlib.figure.Figure at 0x1114ba128>" |
| 854 | + ] |
| 855 | + }, |
| 856 | + "metadata": {}, |
| 857 | + "output_type": "display_data" |
| 858 | + } |
| 859 | + ], |
| 860 | + "source": [ |
| 861 | + "plt.imshow(data_sample[0].view(-1, 28, 28)[0].numpy(), cmap='gray');" |
| 862 | + ] |
| 863 | + }, |
| 864 | + { |
| 865 | + "cell_type": "markdown", |
| 866 | + "metadata": {}, |
| 867 | + "source": [ |
| 868 | + "对于这样一个数据集,我们想用 GANs 的 Generator 自己产生对应的数字。如果用原始的 GANs 形式,需要针对 0-9 这 10 个数字产生对应 10 组模型。对更多类的问题,需要的模型会更多。而且各个模型之间互相独立,没有办法「共享」一部分他们学习的知识,这不是一个很好的解决方案。\n", |
| 869 | + "\n", |
| 870 | + "针对这个问题,Conditional GANs [1] 给出的方案是除了噪音输入 Z,给将每一个 $y$ 的信息也输入给 DM 及 GM(比如 embedding 形式)。其模型示意图(取自 [1])及问题形式如下:\n", |
| 871 | + "\n", |
| 872 | + "\n", |
| 873 | + "\n", |
| 874 | + "$$ \\min_G\\max_D V(D, G) = E_{x \\sim P_{data}(x|y)}[\\log D(x)] + E_{z \\sim P_z(z)}[\\log (1 - D(G(z|y))]$$" |
| 875 | + ] |
| 876 | + }, |
| 877 | + { |
| 878 | + "cell_type": "markdown", |
| 879 | + "metadata": {}, |
742 | 880 | "source": [] |
743 | 881 | } |
744 | 882 | ], |
|
0 commit comments