博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
使用Keras编写GAN的入门
阅读量:4959 次
发布时间:2019-06-12

本文共 8090 字,大约阅读时间需要 26 分钟。

使用Keras编写GAN的入门

Time: 2017-5-31


前言

本文主要参考了网页[1]中的教程,所用算法来自 Ian J. Goodfellow 等人的论文《Generative Adversarial Nets》(2014),算法如下:

gan
gan

代码

%matplotlib inlineimport numpy as npimport pandas as pdfrom keras.models import Modelfrom keras.layers import Dense, Activation, Input, Reshapefrom keras.layers import Conv1D, Flatten, Dropoutfrom keras.optimizers import SGD, Adamfrom tqdm import tqdm_notebook as tqdm  # 进度条# 生成随机正弦曲线的数据def sample_data(n_samples=10000, x_vals=np.arange(0, 5, .1), max_offset=1000, mul_range=[1, 2]):    vectors = []    for i in range(n_samples):        offset = np.random.random() * max_offset        mul = mul_range[0] + np.random.random() * (mul_range[1] - mul_range[0])        vectors.append(np.sin(offset + x_vals * mul) / 2 + .5)            return np.array(vectors)    # 创建生成模型def get_generative(G_in, dense_dim=200, out_dim=50, lr=1e-3):    x = Dense(dense_dim)(G_in)    x = Activation('tanh')(x)    G_out = Dense(out_dim, activation='tanh')(x)    G = Model(G_in, G_out)    opt = SGD(lr=lr)        G.compile(loss='binary_crossentropy', optimizer=opt)        return G, G_out    # 创建判别模型def get_discriminative(D_in, lr=1e-3, drate = .25, n_channels=50, conv_sz=5, leak=.2):    x = Reshape((-1, 1))(D_in)    x = Conv1D(n_channels, conv_sz, activation='relu')(x)    x = Dropout(drate)(x)    x = Flatten()(x)    x = Dense(n_channels)(x)    D_out = Dense(2, activation='sigmoid')(x)    D = Model(D_in, D_out)    dopt = Adam(lr=lr)    D.compile(loss='binary_crossentropy', optimizer=dopt)        return D, D_out        def set_trainability(model, trainable=False):    model.trainable = trainable    for layer in model.layers:        layer.trainable = trainable        def make_gan(GAN_in, G, D):    set_trainability(D, False)    x = G(GAN_in)    GAN_out = D(x)    GAN = Model(GAN_in, GAN_out)    GAN.compile(loss='binary_crossentropy', optimizer=G.optimizer)    return GAN, GAN_out# 通过生成数据 预训练判别模型def sample_data_and_gen(G, noise_dim=10, n_samples=10000):    XT = sample_data(n_samples=n_samples)    XN_noise = np.random.uniform(0, 1, size=[n_samples, noise_dim])    XN = G.predict(XN_noise)    X = 
np.concatenate((XT, XN))    y = np.zeros((2*n_samples, 2))    y[:n_samples, 1] = 1    y[n_samples:, 0] = 1    return X, y     def pretrain(G, D, noise_dim=10, n_samples=10000, batch_size=32):    X, y = sample_data_and_gen(G, noise_dim=noise_dim, n_samples=n_samples)    set_trainability(D, True)    D.fit(X, y, epochs=1, batch_size=batch_size)        # 开始交叉训练步骤def sample_noise(G, noise_dim=10, n_samples=10000):    X = np.random.uniform(0, 1, size=[n_samples, noise_dim])    y = np.zeros((n_samples, 2))    y[:, 1] = 1    return X, y    def train(GAN, G, D, epochs=500, n_samples=10000, noise_dim=10, batch_size=32, verbose=False, v_freq=50):    d_loss = []    g_loss = []    e_range = range(epochs)    if verbose:        e_range = tqdm(e_range)        for epoch in e_range:        X, y = sample_data_and_gen(G, n_samples=n_samples, noise_dim=noise_dim) # 对D进行训练        set_trainability(D, True)        d_loss.append(D.train_on_batch(X, y))                X, y = sample_noise(G, n_samples=n_samples, noise_dim=noise_dim) # 对G训练        set_trainability(D, False)        g_loss.append(GAN.train_on_batch(X, y))                if verbose and (epoch + 1) % v_freq == 0:            print("Epoch #{}: Generative Loss: {}, Discriminative Loss: {}".format(epoch + 1, g_loss[-1], d_loss[-1]))                return d_loss, g_loss
# Plot five real training curves (the shapes G should learn to imitate).
ax = pd.DataFrame(np.transpose(sample_data(5))).plot()

# Generator: 10-dimensional noise in, curve out.
G_in = Input(shape=[10])
G, G_out = get_generative(G_in)
G.summary()

# Discriminator: 50-point curve in, two class scores out.
D_in = Input(shape=[50])
D, D_out = get_discriminative(D_in)
D.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_9 (InputLayer)         (None, 10)                0
_________________________________________________________________
dense_13 (Dense)             (None, 200)               2200
_________________________________________________________________
activation_4 (Activation)    (None, 200)               0
_________________________________________________________________
dense_14 (Dense)             (None, 50)                10050
=================================================================
Total params: 12,250
Trainable params: 12,250
Non-trainable params: 0
_________________________________________________________________
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_10 (InputLayer)        (None, 50)                0
_________________________________________________________________
reshape_4 (Reshape)          (None, 50, 1)             0
_________________________________________________________________
conv1d_4 (Conv1D)            (None, 46, 50)            300
_________________________________________________________________
dropout_4 (Dropout)          (None, 46, 50)            0
_________________________________________________________________
flatten_4 (Flatten)          (None, 2300)              0
_________________________________________________________________
dense_15 (Dense)             (None, 50)                115050
_________________________________________________________________
dense_16 (Dense)             (None, 2)                 102
=================================================================
Total params: 115,452
Trainable params: 115,452
Non-trainable params: 0
_________________________________________________________________

png
png

# Combined model: noise -> G -> frozen D. Training it updates G only.
GAN_in = Input([10])
GAN, GAN_out = make_gan(GAN_in, G, D)
GAN.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_11 (InputLayer)        (None, 10)                0
_________________________________________________________________
model_9 (Model)              (None, 50)                12250
_________________________________________________________________
model_10 (Model)             (None, 2)                 115452
=================================================================
Total params: 127,702
Trainable params: 12,250
Non-trainable params: 115,452
_________________________________________________________________
# Warm up the discriminator before the adversarial loop starts.
pretrain(G=G, D=D)
Epoch 1/1
20000/20000 [==============================] - 3s - loss: 0.0072
# Alternate D and G updates, printing the losses periodically.
d_loss, g_loss = train(GAN=GAN, G=G, D=D, verbose=True)
Epoch #50: Generative Loss: 4.41527795791626, Discriminative Loss: 0.6733301877975464
Epoch #100: Generative Loss: 3.8898046016693115, Discriminative Loss: 0.09901376813650131
Epoch #150: Generative Loss: 6.2410054206848145, Discriminative Loss: 0.034074194729328156
Epoch #200: Generative Loss: 5.206066608428955, Discriminative Loss: 0.13078376650810242
Epoch #250: Generative Loss: 3.5144925117492676, Discriminative Loss: 0.07160962373018265
Epoch #300: Generative Loss: 3.705162525177002, Discriminative Loss: 0.05893774330615997
Epoch #350: Generative Loss: 3.511479616165161, Discriminative Loss: 0.09775738418102264
Epoch #400: Generative Loss: 4.141300678253174, Discriminative Loss: 0.03169865906238556
Epoch #450: Generative Loss: 3.500260829925537, Discriminative Loss: 0.05957922339439392
Epoch #500: Generative Loss: 2.9797921180725098, Discriminative Loss: 0.10566817969083786
# Plot both training-loss curves on a log scale.
ax = pd.DataFrame(
    {
        'Generative Loss': g_loss,
        'Discriminative Loss': d_loss,
    }
).plot(title='Training loss', logy=True)
ax.set_xlabel("Epochs")
ax.set_ylabel("Loss")

png
png

# Draw fresh samples and plot only the generated half.
# sample_data_and_gen() returns the real rows first and the generated rows second.
N_VIEWED_SAMPLES = 2
data_and_gen, _ = sample_data_and_gen(G, n_samples=N_VIEWED_SAMPLES)
pd.DataFrame(np.transpose(data_and_gen[N_VIEWED_SAMPLES:])).plot()

png
png

# Same as above, but smooth each generated curve with a moving average.
N_VIEWED_SAMPLES = 2
SMOOTHING_WINDOW = 5
data_and_gen, _ = sample_data_and_gen(G, n_samples=N_VIEWED_SAMPLES)
generated = pd.DataFrame(np.transpose(data_and_gen[N_VIEWED_SAMPLES:]))
# rolling(w).mean() is NaN only for the first w - 1 rows, so drop exactly
# those rows (slicing from w would also discard the first valid value).
generated.rolling(SMOOTHING_WINDOW).mean().iloc[SMOOTHING_WINDOW - 1:].plot()

png
png

reference

[1]

转载于:https://www.cnblogs.com/flyu6/p/7691130.html

你可能感兴趣的文章
mysql 不常用备忘
查看>>
Mybatis自动化生成代码
查看>>
asp.net 动态添加多附件上传.
查看>>
sscanf()函数
查看>>
WEEX学习网站
查看>>
uDig介绍
查看>>
后台调用外部程序的完美实现
查看>>
python随机数random模块
查看>>
03-body标签中相关标签
查看>>
JavaScript:对Object对象的一些常用操作总结
查看>>
node assert.equal()
查看>>
buf.readUIntBE()
查看>>
Beta 冲刺(1/7)
查看>>
【luogu2747】 [USACO5.4]周游加拿大Canada Tour[动态规划]
查看>>
ubuntu安装mysql 时未提示输入密码
查看>>
L1-006 连续因子
查看>>
RabbitMQ入门(4)——路由(Routing)
查看>>
POJ 1330
查看>>
poj 3687(拓扑排序)
查看>>
jar 打包命令详解
查看>>