网络定义

定义网络,莫凡讲的都是用sequential搭

定义优化器, G=torch.optim.Adam(G.parameter(),lr=LR)

  先得有真实的数据,之前可以定义 def artist_work():

                ...

                return paintings

在训练过程中才定义损失函数

for step in range(10000):

  艺术作品 artist_paintings = artist_work()   # 调用这个函数,返回 return paintings

    生成的作品 G_paintings = G(G_ideas)      # 用随机灵感作为输入,输入到generator中,得到G_paingtings

  而 G_ideas 是怎么来的,用uniform随机生成的五个,,,因为用的是批次训练,所以写G_ideas = torch.randn( , G_ideas)

  训练过程中的损失函数,G和D肯定是分开的,,,G希望生成样本的值越大越好,D希望真实样本值越大越好,生成样本的值越小越小。

  G_loss

  D_loss

  

  每一次训练的时候都需要梯度清零,让优化器梯度清零

  opt_D.zero_grad()                   opt_G.zero_grad()

  D_loss.backward()                  G_loss.backward()

  opt_D.step()                            opt_G.step()

这样就可以训练了

  

原文地址:https://www.cnblogs.com/DoctorZhao/p/13355481.html