当先锋百科网

首页 1 2 3 4 5 6 7

本文参考:pytorch实现简单GAN - 灰信网(软件开发博客聚合)

上文中pytorch代码执行会有问题,这块本文中已经修复! 

1、GAN概述

GAN:Generative Adversarial Nets,生成对抗网络。在给定充分的建模能力,两个博弈模型能够通过简单的反向传播来协同训练。

两个模型的角色定位十分鲜明。给定真实数据集Data,G是生成器(Generator),它的任务是生成能以假乱真的假数据。D是判别器(Discriminator),它从真实数据或者G那里获取数据,然后做出判别真假的标记。

理想情况下,D和G都会随着不断训练做得越来越好,直到G基本上成为一个“赝品制造大事”,而D因无法正确区分两种数据分布输给G。

2、数学建模

设真实数据的概率分布为Pdata,生成器生成数据的概率分布为PG。

(1)D的数学描述

规定D的输出代表输入为”真”的概率(在0~1之间),则D的目标是:

若输入是真品,则提高D(x);若输入是赝品,则降低D(x)。

综合起来用数学语言描述如下:

解释:若x服从p_{data},则log(D(x))越大越好。若x服从P_{G},则log(D(x))越小越好,即log(1-D(x))越大越好。

(2)G的数学描述

对于G来说,它的目标是尽可能提高生成数据被D判别为”真”的概率,数学描述如下:

也即:

(3)全局最优解

生成器生成数据的分布在最优解情况下就等于真实数据的分布。

3、用pytorch实现简单GAN

import numpy as np
import torch.nn as nn
import torch
import matplotlib.pyplot as plt

LR = 0.0001
BATCH_SIZE = 64
DATA_SIZE = 16
IDEA = 5
X = np.linspace(0, 2 * np.pi, DATA_SIZE)


def p_data(x):
    f = np.zeros((BATCH_SIZE, DATA_SIZE))
    for i in range(BATCH_SIZE):
        f[i] = np.sin(x)
    return f


G = nn.Sequential(
    nn.Linear(IDEA, 64),
    nn.ReLU(),
    nn.Linear(64, 128),
    nn.ReLU(),
    nn.Linear(128, DATA_SIZE)
)

D = nn.Sequential(
    nn.Linear(DATA_SIZE, 64),
    nn.ReLU(),
    nn.Linear(64, 256),
    nn.ReLU(),
    nn.Linear(256, 1),
    nn.Sigmoid()
)

D_optimizer = torch.optim.Adam(D.parameters(), lr=LR)
G_optimizer = torch.optim.Adam(G.parameters(), lr=LR)

for step in range(10000):
    real = torch.tensor(p_data(X)).float()
    idea = torch.randn((BATCH_SIZE, IDEA))
    fake = G(idea)

    prob_fake = D(fake)
    G_loss = torch.mean(torch.log(torch.tensor(1) - prob_fake))
    G_optimizer.zero_grad()
    G_loss.backward()
    G_optimizer.step()

    prob_real = D(real)
    prob_fake = D(fake.detach())
    D_loss = -torch.mean((torch.log(prob_real) + torch.log(torch.tensor(1) - prob_fake)))
    D_optimizer.zero_grad()
    D_loss.backward(retain_graph=True)
    D_optimizer.step()

    if step % 100 == 0:
        print(prob_real.mean())
        print(prob_fake.mean())
        print('-----------------------------------------------')
    if torch.abs(prob_real.mean() - 0.5) <= 1.e-6:
        break
    if step % 50 == 0:  # plotting
        plt.cla()
        plt.plot(X, fake.data.numpy()[0], c='red', lw=3, label='Generated painting')
        plt.plot(X, real.data.numpy()[0], c='black', lw=1, label='real painting')
        plt.text(1, .5, 'the prob of Generated painting is real = %.2f' % prob_fake.data.numpy().mean())
        plt.ylim((-1.1, 1.1))
        plt.legend(loc='best', fontsize=10)
        plt.draw()
        plt.pause(0.01)

plt.ioff()
plt.show()