Commit 440ff2d

add Conditional GANs section
1 parent 4f20dc7 commit 440ff2d

File tree

2 files changed: +150 −12 lines changed


imgs/mirza-arxiv-14.png

43.8 KB

learn-gan.ipynb

Lines changed: 150 additions & 12 deletions
@@ -34,7 +34,7 @@
3434
},
3535
{
3636
"cell_type": "code",
37-
"execution_count": 4,
37+
"execution_count": 30,
3838
"metadata": {
3939
"autoscroll": false,
4040
"collapsed": true,
@@ -50,7 +50,7 @@
5050
},
5151
{
5252
"cell_type": "code",
53-
"execution_count": 5,
53+
"execution_count": 31,
5454
"metadata": {
5555
"autoscroll": false,
5656
"collapsed": true,
@@ -335,7 +335,6 @@
335335
]
336336
},
337337
{
338-
"attachments": {},
339338
"cell_type": "markdown",
340339
"metadata": {},
341340
"source": [
@@ -353,7 +352,6 @@
353352
]
354353
},
355354
{
356-
"attachments": {},
357355
"cell_type": "markdown",
358356
"metadata": {},
359357
"source": [
@@ -396,7 +394,9 @@
396394
{
397395
"cell_type": "code",
398396
"execution_count": 31,
399-
"metadata": {},
397+
"metadata": {
398+
"collapsed": true
399+
},
400400
"outputs": [],
401401
"source": [
402402
"# 产生一组数据,模型不知道数据怎么产生的。\n",
@@ -432,7 +432,7 @@
432432
},
433433
{
434434
"cell_type": "code",
435-
"execution_count": 59,
435+
"execution_count": 8,
436436
"metadata": {
437437
"collapsed": true
438438
},
@@ -491,7 +491,9 @@
491491
{
492492
"cell_type": "code",
493493
"execution_count": 56,
494-
"metadata": {},
494+
"metadata": {
495+
"collapsed": true
496+
},
495497
"outputs": [],
496498
"source": [
497499
"dm = Discriminator(2)\n",
@@ -584,7 +586,9 @@
584586
{
585587
"cell_type": "code",
586588
"execution_count": 101,
587-
"metadata": {},
589+
"metadata": {
590+
"collapsed": true
591+
},
588592
"outputs": [],
589593
"source": [
590594
"criterion = nn.BCELoss() # binary cross entropy loss\n",
@@ -602,7 +606,9 @@
602606
{
603607
"cell_type": "code",
604608
"execution_count": 146,
605-
"metadata": {},
609+
"metadata": {
610+
"collapsed": true
611+
},
606612
"outputs": [],
607613
"source": [
608614
"num_epochs = 2000\n",
@@ -688,7 +694,9 @@
688694
{
689695
"cell_type": "code",
690696
"execution_count": 157,
691-
"metadata": {},
697+
"metadata": {
698+
"collapsed": true
699+
},
692700
"outputs": [],
693701
"source": [
694702
"gm_init = Generator(2, 2)\n",
@@ -698,7 +706,9 @@
698706
{
699707
"cell_type": "code",
700708
"execution_count": 158,
701-
"metadata": {},
709+
"metadata": {
710+
"collapsed": true
711+
},
702712
"outputs": [],
703713
"source": [
704714
"g_input_data = g_input.data.numpy()\n",
@@ -732,13 +742,141 @@
732742
" ax.legend()"
733743
]
734744
},
745+
{
746+
"cell_type": "markdown",
747+
"metadata": {
748+
"collapsed": true
749+
},
750+
"source": [
751+
"## Conditional GANs\n",
752+
"以 mnist dataset 为例,解释 Conditional GANs"
753+
]
754+
},
755+
{
756+
"cell_type": "markdown",
757+
"metadata": {},
758+
"source": [
759+
"### MNIST 数据集读取\n",
760+
"- 一个著名的手写数字数据集"
761+
]
762+
},
735763
{
736764
"cell_type": "code",
737-
"execution_count": null,
765+
"execution_count": 77,
738766
"metadata": {
739767
"collapsed": true
740768
},
741769
"outputs": [],
770+
"source": [
771+
"import torchvision.datasets as dset\n",
772+
"import torchvision.transforms as transforms\n",
773+
"import os"
774+
]
775+
},
776+
{
777+
"cell_type": "code",
778+
"execution_count": 78,
779+
"metadata": {},
780+
"outputs": [],
781+
"source": [
782+
"if not os.path.exists('./data'):\n",
783+
" os.mkdir('./data')\n",
784+
" \n",
785+
"train_set = dset.MNIST(root='./data', train=True, \n",
786+
" transform=transforms.Compose([transforms.ToTensor(), \n",
787+
" transforms.Normalize((0.5,), (1.0,))]),\n",
788+
" download=True)"
789+
]
790+
},
791+
{
792+
"cell_type": "code",
793+
"execution_count": 70,
794+
"metadata": {},
795+
"outputs": [],
796+
"source": [
797+
"dataloader = torch.utils.data.DataLoader(train_set, batch_size=200, shuffle=True)"
798+
]
799+
},
800+
{
801+
"cell_type": "markdown",
802+
"metadata": {},
803+
"source": [
804+
"Take a look at the data"
805+
]
806+
},
807+
{
808+
"cell_type": "code",
809+
"execution_count": 71,
810+
"metadata": {},
811+
"outputs": [],
812+
"source": [
813+
"data_sample = next(iter(dataloader))"
814+
]
815+
},
816+
{
817+
"cell_type": "markdown",
818+
"metadata": {},
819+
"source": [
820+
"- dataloader 迭代返回的每个元素是一个 list, 其第一个元素是这批 batch 的图像,第二个元素是这批 batch 的 label\n",
821+
"- 第一个元素是一个 FloatTensor, 大小为 (200 x 1 x 28 x 28),即 (batch_size x channel x height x width)。因为是灰度图像,所以只有一个 channel\n",
822+
"- 因为 channel 这个维度没用,所以我们用 view 把它消解"
823+
]
824+
},
825+
{
826+
"cell_type": "code",
827+
"execution_count": 72,
828+
"metadata": {},
829+
"outputs": [
830+
{
831+
"name": "stdout",
832+
"output_type": "stream",
833+
"text": [
834+
"torch.Size([200, 1, 28, 28])\n",
835+
"torch.Size([200, 28, 28])\n"
836+
]
837+
}
838+
],
839+
"source": [
840+
"print(data_sample[0].size())\n",
841+
"print(data_sample[0].view(-1, 28, 28).size())"
842+
]
843+
},
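The second item of each batch, the labels, is exactly what the Conditional GAN below will condition on. A minimal illustrative snippet (not part of the notebook; it reuses `data_sample` from above) shows how a batch unpacks:

```python
# Illustrative only: unpack one batch from the dataloader defined above.
images, labels = data_sample              # each batch is a list of [images, labels]
print(labels.size())                      # torch.Size([200]) -- one digit label (0-9) per image
print(images.view(-1, 28, 28).size())     # torch.Size([200, 28, 28]) -- channel dim collapsed
```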
844+
{
845+
"cell_type": "code",
846+
"execution_count": 74,
847+
"metadata": {},
848+
"outputs": [
849+
{
850+
"data": {
851+
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAADTpJREFUeJzt3W+oVHUex/HPV82blJGurEilpolRQqaX8sFladlV3AhM\ngsieuGzs7UGRxUaG+2CLCCSsZetBoCXa4pbbP7Ja1lIWbWEVzdQs13Triop5EwW1B2X63Qdz7u7V\n7vxmnDkzZ8bv+wWXO3O+c858Gf3cc878Zs7P3F0A4hlUdAMAikH4gaAIPxAU4QeCIvxAUIQfCIrw\nA0ERfiAowg8ENaSZT2ZmfJwQaDB3t2oeV9ee38xmm9keM9tnZo/Xsy0AzWW1frbfzAZL+kLSTEkH\nJW2RNM/dP0+sw54faLBm7PlvkbTP3b909+8lvSZpTh3bA9BE9YT/KkkH+t0/mC07h5l1m9lWM9ta\nx3MByFnD3/Bz96WSlkoc9gOtpJ49/yFJ1/S7f3W2DEAbqCf8WyRNMrNrzWyopHskrcmnLQCNVvNh\nv7v/YGYPSlorabCk5e7+WW6dAWiomof6anoyzvmBhmvKh3wAtC/CDwRF+IGgCD8QFOEHgiL8QFCE\nHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCauqlu1Gbjo6OZP2TTz4pWxsyJP1PPG3atGT91KlTyTra\nF3t+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiKcf42MH369GT9+uuvr3nbw4YNS9YZ5794secHgiL8\nQFCEHwiK8ANBEX4gKMIPBEX4gaDqGuc3sx5JJyWdkfSDu3fm0RTONWgQf6ORvzw+5PNzdz+aw3YA\nNBG7FCCoesPvkj4ws4/NrDuPhgA0R72H/V3ufsjMfirpQzP7t7tv7P+A7I8CfxiAFlPXnt/dD2W/\neyW9LemWAR6z1N07eTMQaC01h9/MLjOz4X23Jc2StCuvxgA0Vj2H/aMlvW1mfdv5i7v/PZeuADRc\nzeF39y8l3ZRjLyijq6urYdu+6ab0P+G6desa9twoFkN9QFCEHwiK8ANBEX4gKMIPBEX4gaC4dHcb\n2LFjR1tuG62NPT8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBMU4fxuo9LXblDNnziTr7l7zttHe2PMD\nQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCM81/k3njjjWT96FEmWI6KPT8QFOEHgiL8QFCEHwiK8ANB\nEX4gKMIPBFVxnN/Mlku6Q1Kvu0/Jlo2UtFrSeEk9ku529+ONaxPt6tZbby1bW7JkSXLdxYsXJ+vv\nv/9+TT2hpJo9/wpJs89b9rik9e4+SdL67D6ANlIx/O6+UdKx8xbPkbQyu71S0p059wWgwWo95x/t\n7oez219LGp1TPwCapO7P9ru7m1nZC8GZWbek7nqfB0C+at3zHzGzMZKU/e4t90B3X+rune7eWeNz\nAWiAWsO/RtL87PZ8Se/k0w6AZqkYfjN7VdK/JE02s4Nmdp+kxZJmmtleSb/M7gNoI9bM67an3htA\neXv37k3WJ06cWLY2a9as5Lrr1q2rqac+Q4cOTdbffffdsrWZM2cm1/3uu++S9dmzzx+BPteGDRuS\n9YuVu1s1j+MTfkBQhB8IivADQRF+ICjCDwRF+IGguHR3G7j00kuT9f3795et7dy5M+92zvHQQw8l\n65WG81I6OjqS9RkzZiTrUYf6qsWeHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeC4iu9LWDs2LHJ+vbt\n25P1r776qmxt+vTpNfXUp6urK1lfsWJFsj5hwoSytT179iTXnTx5crK+bdu2ZL2zM+bFo/hKL4Ak\nwg8ERfiBoAg/EBThB4Ii/EBQhB8Iiu/zt4BKY/FXXnllw5570KD03/8FCxYk66lxfEk6cOBA2dq0\nadOS627atClZv+KKK5J1pLHnB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGgKo7zm9lySXdI6nX3Kdmy\nJyT9VtI32cMWufvfGtUkGueRRx5J1u+6665k/ZtvvknW58yZU7ZWaXrvYcOGJetbtmxJ1pFWzZ5/\nhaSBJkL/o7tPzX4IPtBmKobf3TdKOtaEXgA0UT3n/A+a2U4zW25mI3LrCEBT1Br+FyVNlDRV0mFJ\nz5Z7oJl1m9lWM9ta43MBaICawu/uR9z9jLuflbRM0i2Jxy519053j3k1RaBF1RR+MxvT7+5cSbvy\naQdAs1Qz1PeqpNskjTKzg5L+IOk2M5sqySX1SLq/gT0CaICK4Xf3eQMsfrkBvaAACxcurGv9tWvX\nJuupOQeuu+665LojR45M1k+fPp2sI41P+AFBEX4gKMIPBEX4gaAIPxAU4QeC4tLdF4HUV2NfeOGF\n5LqjRo1K1isNpz3//PPJ+uDBg8vW7r333uS6I0akvzIyY8aMZB1p7PmBoAg/EBThB4Ii/EBQhB8I\nivADQRF+ICjG+VvA5s2bk/Xjx48n6zfeeGNNtWqcOnUqWZ80aVKyvmjRorK1uXPnJtd192T9qaee\nStaRxp4fCIrwA0ERfiAowg8ERfiBoAg/EBThB4KySmOpuT6ZWfOe7CLyzDPPJOuPPvpokzpprkrj\n+E8++WSyfvbs2TzbaRvubtU8jj0/EBThB4Ii/EBQhB8IivADQRF+ICjCDwRVcZzfzK6R9Iqk0ZJc\n0lJ3/5OZjZS0WtJ4ST2S7nb35BfPGeevTera95I0ZcqUsrVNmzYl1+3o6Kippz7ffvttst7T01O2\n9vTTTyfXXb16dbLezM+otJM8x/l/kPQ7d79B0gxJD5jZDZIel7Te3SdJWp/dB9AmKobf3Q+7+7bs\n9klJuyVdJWmOpJXZw1ZKurNRTQLI3wWd85vZeEk3S9osabS7H85KX6t0WgCgTVR9DT8zu1zSm5Ie\ndvcTZv8/rXB3L3c+b2bdkrrrbRRAvqra85vZJSoFf5W7v5UtPmJmY7L6GEm9A63r7kvdvdPdO/No\nGEA+KobfSrv4lyXtdvfn+pXWSJqf3Z4v6Z382wPQKNUM9XVJ+kjSp5L6viO5SKXz/r9KGitpv0pD\nfccqbIuxmSZbtWpVsj5v3ry6tv/YY48l60uWLKlr+7hw1Q71VTznd/d/Siq3sV9cSFMAWgef8AOC\nIvxAUIQfCIrwA0ERfiAowg8ExaW7L3LDhw9P1seNG5esv/7668n6iRMnkvVly5aVrb300kvJdVEb\nLt0NIInwA0ERfiAowg8ERfiBoAg/EBThB4JinB+4yDDODyCJ8ANBEX4gKMIPBEX4gaAIPxAU4QeC\nIvxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4KqGH4zu8bM/mFmn5vZZ2a2IFv+hJkdMrPt2c/t\njW8XQF4qXszDzMZIGuPu28xsuKSPJd0p6W5Jp9x9SdVPxsU8gIar9mIeQ6rY0GFJh7PbJ81st6Sr\n6msPQNEu6JzfzMZLulnS5mzRg2a208y
Wm9mIMut0m9lWM9taV6cAclX1NfzM7HJJGyQ97e5vmdlo\nSUcluaSnVDo1+E2FbXDYDzRYtYf9VYXfzC6R9J6kte7+3AD18ZLec/cpFbZD+IEGy+0CnmZmkl6W\ntLt/8LM3AvvMlbTrQpsEUJxq3u3vkvSRpE8lnc0WL5I0T9JUlQ77eyTdn705mNoWe36gwXI97M8L\n4Qcaj+v2A0gi/EBQhB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwiK8ANBEX4gKMIPBFXx\nAp45Oyppf7/7o7JlrahVe2vVviR6q1WevY2r9oFN/T7/j57cbKu7dxbWQEKr9taqfUn0VquieuOw\nHwiK8ANBFR3+pQU/f0qr9taqfUn0VqtCeiv0nB9AcYre8wMoSCHhN7PZZrbHzPaZ2eNF9FCOmfWY\n2afZzMOFTjGWTYPWa2a7+i0baWYfmtne7PeA06QV1FtLzNycmFm60Neu1Wa8bvphv5kNlvSFpJmS\nDkraImmeu3/e1EbKMLMeSZ3uXviYsJn9TNIpSa/0zYZkZs9IOubui7M/nCPcfWGL9PaELnDm5gb1\nVm5m6V+rwNcuzxmv81DEnv8WSfvc/Ut3/17Sa5LmFNBHy3P3jZKOnbd4jqSV2e2VKv3naboyvbUE\ndz/s7tuy2ycl9c0sXehrl+irEEWE/ypJB/rdP6jWmvLbJX1gZh+bWXfRzQxgdL+Zkb6WNLrIZgZQ\ncebmZjpvZumWee1qmfE6b7zh92Nd7j5N0q8kPZAd3rYkL52ztdJwzYuSJqo0jdthSc8W2Uw2s/Sb\nkh529xP9a0W+dgP0VcjrVkT4D0m6pt/9q7NlLcHdD2W/eyW9rdJpSis50jdJava7t+B+/sfdj7j7\nGXc/K2mZCnztspml35S0yt3fyhYX/toN1FdRr1sR4d8iaZKZXWtmQyXdI2lNAX38iJldlr0RIzO7\nTNIstd7sw2skzc9uz5f0ToG9nKNVZm4uN7O0Cn7tWm7Ga3dv+o+k21V6x/8/kn5fRA9l+pogaUf2\n81nRvUl6VaXDwNMqvTdyn6SfSFovaa+kdZJGtlBvf1ZpNuedKgVtTEG9dal0SL9T0vbs5/aiX7tE\nX4W8bnzCDwiKN/yAoAg/EBThB4Ii/EBQhB8IivADQRF+ICjCDwT1X4+CPOHbV0x+AAAAAElFTkSu\nQmCC\n",
852+
"text/plain": [
853+
"<matplotlib.figure.Figure at 0x1114ba128>"
854+
]
855+
},
856+
"metadata": {},
857+
"output_type": "display_data"
858+
}
859+
],
860+
"source": [
861+
"plt.imshow(data_sample[0].view(-1, 28, 28)[0].numpy(), cmap='gray');"
862+
]
863+
},
864+
{
865+
"cell_type": "markdown",
866+
"metadata": {},
867+
"source": [
868+
"对于这样一个数据集,我们想用 GANs 的 Generator 自己产生对应的数字。如果用原始的 GANs 形式,需要针对 0-9 这 10 个数字产生对应 10 组模型。对更多类的问题,需要的模型会更多。而且各个模型之间互相独立,没有办法「共享」一部分他们学习的知识,这不是一个很好的解决方案。\n",
869+
"\n",
870+
"针对这个问题,Conditional GANs [1] 给出的方案是除了噪音输入 Z,给将每一个 $y$ 的信息也输入给 DM 及 GM(比如 embedding 形式)。其模型示意图(取自 [1])及问题形式如下:\n",
871+
"\n",
872+
"![](./imgs/mirza-arxiv-14.png)\n",
873+
"\n",
874+
"$$ \\min_G\\max_D V(D, G) = E_{x \\sim P_{data}(x|y)}[\\log D(x)] + E_{z \\sim P_z(z)}[\\log (1 - D(G(z|y))]$$"
875+
]
876+
},
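To make the idea concrete, here is a minimal PyTorch sketch (not the notebook's own code) of a label-conditioned generator and discriminator for MNIST: the digit label $y$ goes through an `nn.Embedding` and is concatenated with the noise vector (for G) or the flattened image (for D). All names and sizes here (`CondGenerator`, `CondDiscriminator`, `noise_dim`, `embed_dim`, the 256-unit hidden layers) are illustrative assumptions, not the notebook's actual classes.

```python
import torch
import torch.nn as nn

class CondGenerator(nn.Module):
    """G(z | y): concatenate noise with a learned label embedding."""
    def __init__(self, noise_dim=100, n_classes=10, embed_dim=10, img_dim=28 * 28):
        super().__init__()
        self.embed = nn.Embedding(n_classes, embed_dim)
        self.net = nn.Sequential(
            nn.Linear(noise_dim + embed_dim, 256), nn.ReLU(),
            nn.Linear(256, img_dim), nn.Tanh(),   # flattened 28x28 image
        )

    def forward(self, z, y):
        return self.net(torch.cat([z, self.embed(y)], dim=1))

class CondDiscriminator(nn.Module):
    """D(x | y): concatenate the flattened image with a label embedding."""
    def __init__(self, n_classes=10, embed_dim=10, img_dim=28 * 28):
        super().__init__()
        self.embed = nn.Embedding(n_classes, embed_dim)
        self.net = nn.Sequential(
            nn.Linear(img_dim + embed_dim, 256), nn.LeakyReLU(0.2),
            nn.Linear(256, 1), nn.Sigmoid(),      # probability that x is real given y
        )

    def forward(self, x, y):
        return self.net(torch.cat([x, self.embed(y)], dim=1))

# One G/D pair now covers all 10 digit classes instead of 10 separate GANs.
g, d = CondGenerator(), CondDiscriminator()
z = torch.randn(200, 100)              # a batch of noise
y = torch.randint(0, 10, (200,))       # the digit each sample should depict
fake = g(z, y)                         # (200, 784) conditioned samples
p_real = d(fake, y)                    # (200, 1), can feed nn.BCELoss as before
```

In training, real images would be flattened with `view(-1, 28 * 28)` and passed to the discriminator together with their true labels, while the generator tries to fool the discriminator for randomly drawn labels, using the same `nn.BCELoss` setup as in the unconditional GAN above.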
877+
{
878+
"cell_type": "markdown",
879+
"metadata": {},
742880
"source": []
743881
}
744882
],
