直接贴代码:
首先导入相应的包:
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image
然后对数据进行预处理:
# Preprocessing: ToTensor gives pixels in [0, 1]; Normalize((x - 0.5) / 0.5)
# then maps them to [-1, 1], the same range as the generator's Tanh output.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

# MNIST training split; download=True fetches it into ./data/mnist if needed.
train_set = torchvision.datasets.MNIST(
    "./data/mnist",
    train=True,
    transform=transform,
    download=True,
)

# Shuffled mini-batches of 8 images.
# NOTE(review): keep this 8 in sync with the `batch_size` hyperparameter.
train_loader = torch.utils.data.DataLoader(train_set, batch_size=8, shuffle=True)
定义网络模型:
class Discriminator(nn.Module):
    """Convolutional discriminator for single-channel 28x28 images.

    Maps a batch of images to one Sigmoid score per image, used by the
    training loop as the probability that the image is real.
    """

    def __init__(self):
        super().__init__()
        # [N, 1, 28, 28] -> conv -> [N, 32, 28, 28] -> avg-pool -> [N, 32, 14, 14]
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 32, 5, padding=2),
            nn.LeakyReLU(0.2, True),
            nn.AvgPool2d(2, stride=2),
        )
        # [N, 32, 14, 14] -> conv -> [N, 64, 14, 14] -> avg-pool -> [N, 64, 7, 7]
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, 5, padding=2),
            nn.LeakyReLU(0.2, True),
            nn.AvgPool2d(2, stride=2),
        )
        # Flattened [N, 64*7*7] features -> one score in (0, 1) per image.
        self.fc = nn.Sequential(
            nn.Linear(64 * 7 * 7, 1024),
            nn.LeakyReLU(0.2, True),
            nn.Linear(1024, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Score a batch of images.

        Args:
            x: images of shape [N, 1, 28, 28]; other spatial sizes do not
                flatten to the 64*7*7 features ``self.fc`` expects.

        Returns:
            Tensor of shape [N, 1] with values in (0, 1).
        """
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)  # [N, 64*7*7]
        return self.fc(x)
class Generator(nn.Module):
    """Generator mapping a noise vector to a single-channel 28x28 image.

    Args:
        input_size: length of the input noise vector (the z dimension).
        num_feature: output width of the first linear layer. ``forward``
            reshapes it to [N, 1, 56, 56], so it must equal 56 * 56.
    """

    def __init__(self, input_size, num_feature):
        super().__init__()
        # [N, input_size] -> [N, num_feature] (56*56 in this script).
        self.fc = nn.Linear(input_size, num_feature)
        # Batch-normalise and activate the reshaped 1x56x56 feature map.
        self.br = nn.Sequential(
            nn.BatchNorm2d(1),
            nn.ReLU(True),
        )
        # [N, 1, 56, 56] -> [N, 50, 56, 56]
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 50, 3, stride=1, padding=1),
            nn.BatchNorm2d(50),
            nn.ReLU(True),
        )
        # [N, 50, 56, 56] -> [N, 25, 56, 56]
        self.conv2 = nn.Sequential(
            nn.Conv2d(50, 25, 3, stride=1, padding=1),
            nn.BatchNorm2d(25),
            nn.ReLU(True),
        )
        # A 2x2 kernel with stride 2 halves the resolution: -> [N, 1, 28, 28].
        # Tanh keeps pixels in [-1, 1], matching the Normalize([0.5], [0.5]) data.
        self.conv3 = nn.Sequential(
            nn.Conv2d(25, 1, 2, stride=2),
            nn.Tanh(),
        )

    def forward(self, x):
        """Map noise of shape [N, input_size] to images of shape [N, 1, 28, 28]."""
        x = self.fc(x)
        x = x.view(x.size(0), 1, 56, 56)
        x = self.br(x)
        x = self.conv1(x)
        x = self.conv2(x)
        return self.conv3(x)
设置超参数以及初始化模型:
# Hyperparameters.
lr = 3e-4
batch_size = 8  # keep in sync with the DataLoader's batch_size
z_dimension = 100  # length of the generator's input noise vector
n_epoch = 20

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

G = Generator(z_dimension, 56 * 56).to(device)
D = Discriminator().to(device)

# Loss used by the training loop for both networks. D already ends in a
# Sigmoid, so plain BCELoss (not BCEWithLogitsLoss) is the matching criterion.
criterion = nn.BCELoss()

G_optimizer = optim.Adam(G.parameters(), lr)
D_optimizer = optim.Adam(D.parameters(), lr)
开始训练:
for epoch in range(n_epoch):
    for index, (images, _labels) in enumerate(train_loader):
        real_images = images.to(device)
        # Size targets and noise from the actual batch (not the configured
        # batch_size), so a short final batch cannot cause a shape mismatch.
        # Targets are [N, 1] to match D's output shape.
        cur_batch_size = real_images.size(0)
        real_labels = torch.ones(cur_batch_size, 1, device=device)
        fake_labels = torch.zeros(cur_batch_size, 1, device=device)

        # ---- Train the discriminator: real -> 1, fake -> 0 ----
        real_out = D(real_images)
        D_loss_real = criterion(real_out, real_labels)

        z = torch.randn(cur_batch_size, z_dimension, device=device)
        fake_images = G(z)
        # detach(): this step only updates D, so don't backpropagate into G.
        fake_out = D(fake_images.detach())
        D_loss_fake = criterion(fake_out, fake_labels)

        D_loss = (D_loss_real + D_loss_fake) / 2.0
        D_optimizer.zero_grad()
        D_loss.backward()
        D_optimizer.step()

        # ---- Train the generator: make D label fresh fakes as real ----
        z = torch.randn(cur_batch_size, z_dimension, device=device)
        fake_images = G(z)
        fake_out = D(fake_images)
        G_loss = criterion(fake_out, real_labels)
        G_optimizer.zero_grad()
        G_loss.backward()
        G_optimizer.step()

        if (index + 1) % 100 == 0:
            print("[%d/%d] [%d/%d] G_loss: %.06f D_loss: %.06f"
                  % (epoch + 1, n_epoch, index + 1, len(train_loader),
                     G_loss.item(), D_loss.item()))

    # After each epoch, save a grid of samples; no gradients are needed here.
    with torch.no_grad():
        z = torch.randn(batch_size, z_dimension, device=device)
        imgs = G(z)
    os.makedirs("./images", exist_ok=True)
    # G outputs values in [-1, 1]; undo Normalize([0.5], [0.5]) so
    # save_image receives the [0, 1] range it expects instead of clamping.
    save_image(imgs * 0.5 + 0.5, "./images/images_%d.png" % (epoch + 1))
这是第一个epoch的结果:

这是第10个epoch的结果:

这是第20个epoch的结果:

可以看到,生成图片的质量越来越好,即越来越像训练数据中的图片。
# NOTE(review): this repeats the earlier training loop; if both are run,
# training continues for a further n_epoch epochs and overwrites the samples.
for epoch in range(n_epoch):
    for index, (images, _labels) in enumerate(train_loader):
        real_images = images.to(device)
        # D outputs [N, 1]; 1-D [N] targets make BCELoss raise a size
        # mismatch, so targets are [N, 1], sized from the actual batch.
        cur_batch_size = real_images.size(0)
        real_labels = torch.ones(cur_batch_size, 1, device=device)
        fake_labels = torch.zeros(cur_batch_size, 1, device=device)

        # ---- Train the discriminator: real -> 1, fake -> 0 ----
        real_out = D(real_images)
        D_loss_real = criterion(real_out, real_labels)

        z = torch.randn(cur_batch_size, z_dimension, device=device)
        fake_images = G(z)
        # detach(): this step only updates D, so don't backpropagate into G.
        fake_out = D(fake_images.detach())
        D_loss_fake = criterion(fake_out, fake_labels)

        D_loss = (D_loss_real + D_loss_fake) / 2.0
        D_optimizer.zero_grad()
        D_loss.backward()
        D_optimizer.step()

        # ---- Train the generator: make D label fresh fakes as real ----
        z = torch.randn(cur_batch_size, z_dimension, device=device)
        fake_images = G(z)
        fake_out = D(fake_images)
        G_loss = criterion(fake_out, real_labels)
        G_optimizer.zero_grad()
        G_loss.backward()
        G_optimizer.step()

        if (index + 1) % 100 == 0:
            print("[%d/%d] [%d/%d] G_loss: %.06f D_loss: %.06f"
                  % (epoch + 1, n_epoch, index + 1, len(train_loader),
                     G_loss.item(), D_loss.item()))

    # After each epoch, save a grid of samples; no gradients are needed here.
    with torch.no_grad():
        z = torch.randn(batch_size, z_dimension, device=device)
        imgs = G(z)
    os.makedirs("./images", exist_ok=True)
    # G outputs values in [-1, 1]; map back to [0, 1] for save_image.
    save_image(imgs * 0.5 + 0.5, "./images/images_%d.png" % (epoch + 1))
原文:https://www.cnblogs.com/liualexsone/p/12543346.html