Image-to-Image Translation (pix2pix) in PyTorch
1. Introduction
pix2pix can be used for style transfer, image inpainting, semantic segmentation, image colorization, and more. Here the goal is to translate facades labels into realistic photos. The official source code covers many features and is fairly complex, so I extracted the core parts and made a few small modifications.
In a conventional GAN, the generator G takes a noise vector z as input, and the goal is for G to learn a distribution that fits the real data distribution. In pix2pix, the input to G is a facades label map: G must produce images realistic enough to fool D while also staying close to the real photo paired with that label (enforced with an L1 loss). Noise enters G only through its dropout layers (without this source of noise, the trained G would merely reproduce the existing training samples instead of fitting the real data distribution).
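For reference, this is the objective from the pix2pix paper; the training code below replaces the log-likelihood GAN terms with a least-squares (MSE) criterion but keeps the same structure, with $\lambda = 100$:

$$\mathcal{L}_{cGAN}(G,D) = \mathbb{E}_{x,y}[\log D(x,y)] + \mathbb{E}_{x}[\log(1 - D(x, G(x)))]$$

$$G^{*} = \arg\min_{G}\max_{D}\,\mathcal{L}_{cGAN}(G,D) + \lambda\,\mathbb{E}_{x,y}\big[\lVert y - G(x)\rVert_{1}\big]$$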
The discriminator D in pix2pix is a PatchGAN: instead of producing a single real/fake score for the whole image, D divides the image into many patches and scores each one. D gives high scores only when the image both matches the facades labels and looks close to a real sample.
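The idea in a minimal sketch (the 30*30 map and the MSE criterion match the 256*256 discriminator and training code below; patch_scores is a hypothetical stand-in for D's output):

    import torch
    import torch.nn as nn

    patch_scores = torch.rand(1, 1, 30, 30)  # stand-in for D(real_or_fake, label)
    mse = nn.MSELoss()
    # every patch score is pushed toward 1 for a matching real pair ...
    ls_real = mse(patch_scores, torch.ones_like(patch_scores))
    # ... and toward 0 for a fake pair, rather than scoring the whole image once
    ls_fake = mse(patch_scores, torch.zeros_like(patch_scores))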
2. Dataset
The facades dataset is used. Training samples come in matched pairs (each facades label image is stored as a .png, and the corresponding real photo as a .jpg), 606 images in total; during training the data is split 8:2 into a training set and a validation set. Some sample image pairs are shown below.
[Figure: sample label/photo pairs from the facades dataset]
3. Network Architecture
The generator G uses a U-Net, which is essentially an encoder-decoder with skip connections; the architecture is shown below (for a 256*256 input). Dropout layers appear only in the blocks whose input and output channels are both 512.
[Figure: U-Net generator architecture for a 256*256 input]
The discriminator D is just a plain convolutional network; its architecture is shown below.
[Figure: discriminator architecture]
4. Code
(1) net
Two network configurations are provided here, for 128*128 and 256*256 inputs.
import torch.nn as nn
import torch
from collections import OrderedDict


# Downsampling block: LeakyReLU -> stride-2 conv -> BN (halves the spatial size)
class downsample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(downsample, self).__init__()
        self.down = nn.Sequential(
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x):
        return self.down(x)


# Upsampling block: ReLU -> stride-2 transposed conv -> BN (-> optional dropout)
class upsample(nn.Module):
    def __init__(self, in_channels, out_channels, drop_out=False):
        super(upsample, self).__init__()
        self.up = nn.Sequential(
            nn.ReLU(True),
            nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.Dropout(0.5) if drop_out else nn.Identity()
        )

    def forward(self, x):
        return self.up(x)
# ---------------------------------------------------------------------------------
# Generator G for 128*128 input
class pix2pixG_128(nn.Module):
    def __init__(self):
        super(pix2pixG_128, self).__init__()
        # down sample
        self.down_1 = nn.Conv2d(3, 64, 4, 2, 1)  # [batch,3,128,128]=>[batch,64,64,64]
        self.down_2 = downsample(64, 128)   # [batch,64,64,64]=>[batch,128,32,32]
        self.down_3 = downsample(128, 256)  # [batch,128,32,32]=>[batch,256,16,16]
        self.down_4 = downsample(256, 512)  # [batch,256,16,16]=>[batch,512,8,8]
        self.down_5 = downsample(512, 512)  # [batch,512,8,8]=>[batch,512,4,4]
        self.down_6 = downsample(512, 512)  # [batch,512,4,4]=>[batch,512,2,2]
        self.down_7 = downsample(512, 512)  # [batch,512,2,2]=>[batch,512,1,1]
        # up sample (input channels are doubled by the skip connections)
        self.up_1 = upsample(512, 512)                  # [batch,512,1,1]=>[batch,512,2,2]
        self.up_2 = upsample(1024, 512, drop_out=True)  # [batch,1024,2,2]=>[batch,512,4,4]
        self.up_3 = upsample(1024, 512, drop_out=True)  # [batch,1024,4,4]=>[batch,512,8,8]
        self.up_4 = upsample(1024, 256)                 # [batch,1024,8,8]=>[batch,256,16,16]
        self.up_5 = upsample(512, 128)                  # [batch,512,16,16]=>[batch,128,32,32]
        self.up_6 = upsample(256, 64)                   # [batch,256,32,32]=>[batch,64,64,64]
        self.last_Conv = nn.Sequential(
            nn.ConvTranspose2d(in_channels=128, out_channels=3, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )
        self.init_weight()

    def init_weight(self):
        for w in self.modules():
            if isinstance(w, nn.Conv2d):
                nn.init.kaiming_normal_(w.weight, mode='fan_out')
                if w.bias is not None:
                    nn.init.zeros_(w.bias)
            elif isinstance(w, nn.ConvTranspose2d):
                nn.init.kaiming_normal_(w.weight, mode='fan_in')
            elif isinstance(w, nn.BatchNorm2d):
                nn.init.ones_(w.weight)
                nn.init.zeros_(w.bias)

    def forward(self, x):
        # encoder
        down_1 = self.down_1(x)
        down_2 = self.down_2(down_1)
        down_3 = self.down_3(down_2)
        down_4 = self.down_4(down_3)
        down_5 = self.down_5(down_4)
        down_6 = self.down_6(down_5)
        down_7 = self.down_7(down_6)
        # decoder with skip connections
        up_1 = self.up_1(down_7)
        up_2 = self.up_2(torch.cat([up_1, down_6], dim=1))
        up_3 = self.up_3(torch.cat([up_2, down_5], dim=1))
        up_4 = self.up_4(torch.cat([up_3, down_4], dim=1))
        up_5 = self.up_5(torch.cat([up_4, down_3], dim=1))
        up_6 = self.up_6(torch.cat([up_5, down_2], dim=1))
        out = self.last_Conv(torch.cat([up_6, down_1], dim=1))
        return out
# PatchGAN discriminator D for 128*128 input
class pix2pixD_128(nn.Module):
    def __init__(self):
        super(pix2pixD_128, self).__init__()

        # basic conv -> BN -> LeakyReLU block
        def base_Conv_bn_lkrl(in_channels, out_channels, stride):
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 4, stride, 1),
                nn.BatchNorm2d(out_channels),
                nn.LeakyReLU(0.2)
            )

        D_dic = OrderedDict()
        in_channels = 6  # label map and photo concatenated along channels
        out_channels = 64
        for i in range(4):
            if i < 3:
                D_dic.update({'layer_{}'.format(i + 1): base_Conv_bn_lkrl(in_channels, out_channels, 2)})
            else:
                D_dic.update({'layer_{}'.format(i + 1): base_Conv_bn_lkrl(in_channels, out_channels, 1)})
            in_channels = out_channels
            out_channels *= 2
        D_dic.update({'last_layer': nn.Conv2d(512, 1, 4, 1, 1)})  # [batch,1,14,14]
        self.D_model = nn.Sequential(D_dic)

    def forward(self, x1, x2):
        in_x = torch.cat([x1, x2], dim=1)
        return self.D_model(in_x)
# ---------------------------------------------------------------------------------
# Generator G for 256*256 input
class pix2pixG_256(nn.Module):
    def __init__(self):
        super(pix2pixG_256, self).__init__()
        # down sample
        self.down_1 = nn.Conv2d(3, 64, 4, 2, 1)  # [batch,3,256,256]=>[batch,64,128,128]
        self.down_2 = downsample(64, 128)   # [batch,64,128,128]=>[batch,128,64,64]
        self.down_3 = downsample(128, 256)  # [batch,128,64,64]=>[batch,256,32,32]
        self.down_4 = downsample(256, 512)  # [batch,256,32,32]=>[batch,512,16,16]
        self.down_5 = downsample(512, 512)  # [batch,512,16,16]=>[batch,512,8,8]
        self.down_6 = downsample(512, 512)  # [batch,512,8,8]=>[batch,512,4,4]
        self.down_7 = downsample(512, 512)  # [batch,512,4,4]=>[batch,512,2,2]
        self.down_8 = downsample(512, 512)  # [batch,512,2,2]=>[batch,512,1,1]
        # up sample (input channels are doubled by the skip connections)
        self.up_1 = upsample(512, 512)                  # [batch,512,1,1]=>[batch,512,2,2]
        self.up_2 = upsample(1024, 512, drop_out=True)  # [batch,1024,2,2]=>[batch,512,4,4]
        self.up_3 = upsample(1024, 512, drop_out=True)  # [batch,1024,4,4]=>[batch,512,8,8]
        self.up_4 = upsample(1024, 512)                 # [batch,1024,8,8]=>[batch,512,16,16]
        self.up_5 = upsample(1024, 256)                 # [batch,1024,16,16]=>[batch,256,32,32]
        self.up_6 = upsample(512, 128)                  # [batch,512,32,32]=>[batch,128,64,64]
        self.up_7 = upsample(256, 64)                   # [batch,256,64,64]=>[batch,64,128,128]
        self.last_Conv = nn.Sequential(
            nn.ConvTranspose2d(in_channels=128, out_channels=3, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )
        self.init_weight()

    def init_weight(self):
        for w in self.modules():
            if isinstance(w, nn.Conv2d):
                nn.init.kaiming_normal_(w.weight, mode='fan_out')
                if w.bias is not None:
                    nn.init.zeros_(w.bias)
            elif isinstance(w, nn.ConvTranspose2d):
                nn.init.kaiming_normal_(w.weight, mode='fan_in')
            elif isinstance(w, nn.BatchNorm2d):
                nn.init.ones_(w.weight)
                nn.init.zeros_(w.bias)

    def forward(self, x):
        # encoder
        down_1 = self.down_1(x)
        down_2 = self.down_2(down_1)
        down_3 = self.down_3(down_2)
        down_4 = self.down_4(down_3)
        down_5 = self.down_5(down_4)
        down_6 = self.down_6(down_5)
        down_7 = self.down_7(down_6)
        down_8 = self.down_8(down_7)
        # decoder with skip connections
        up_1 = self.up_1(down_8)
        up_2 = self.up_2(torch.cat([up_1, down_7], dim=1))
        up_3 = self.up_3(torch.cat([up_2, down_6], dim=1))
        up_4 = self.up_4(torch.cat([up_3, down_5], dim=1))
        up_5 = self.up_5(torch.cat([up_4, down_4], dim=1))
        up_6 = self.up_6(torch.cat([up_5, down_3], dim=1))
        up_7 = self.up_7(torch.cat([up_6, down_2], dim=1))
        out = self.last_Conv(torch.cat([up_7, down_1], dim=1))
        return out
# PatchGAN discriminator D for 256*256 input
class pix2pixD_256(nn.Module):
    def __init__(self):
        super(pix2pixD_256, self).__init__()

        # basic conv -> BN -> LeakyReLU block
        def base_Conv_bn_lkrl(in_channels, out_channels, stride):
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 4, stride, 1),
                nn.BatchNorm2d(out_channels),
                nn.LeakyReLU(0.2)
            )

        D_dic = OrderedDict()
        in_channels = 6  # label map and photo concatenated along channels
        out_channels = 64
        for i in range(4):
            if i < 3:
                D_dic.update({'layer_{}'.format(i + 1): base_Conv_bn_lkrl(in_channels, out_channels, 2)})
            else:
                D_dic.update({'layer_{}'.format(i + 1): base_Conv_bn_lkrl(in_channels, out_channels, 1)})
            in_channels = out_channels
            out_channels *= 2
        D_dic.update({'last_layer': nn.Conv2d(512, 1, 4, 1, 1)})  # [batch,1,30,30]
        self.D_model = nn.Sequential(D_dic)

    def forward(self, x1, x2):
        in_x = torch.cat([x1, x2], dim=1)
        return self.D_model(in_x)
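A quick sanity check of the tensor shapes (a minimal sketch that can be appended to the file above; a batch of 2 is used because BatchNorm2d cannot normalize the single 1*1 activation of a batch-1 input in training mode):

    if __name__ == '__main__':
        x = torch.rand(2, 3, 256, 256)
        G = pix2pixG_256()
        D = pix2pixD_256()
        fake = G(x)
        print(fake.shape)        # torch.Size([2, 3, 256, 256])
        print(D(fake, x).shape)  # torch.Size([2, 1, 30, 30]) -> one score per patch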
(2) dataset
split_data.py
import random
import glob


def split_data(dir_root):
    random.seed(0)
    ori_img = glob.glob(dir_root + '/*.png')
    k = 0.2  # validation fraction (8:2 split)
    train_ori_imglist = []
    val_ori_imglist = []
    sample_data = random.sample(population=ori_img, k=int(k * len(ori_img)))
    for img in ori_img:
        if img in sample_data:
            val_ori_imglist.append(img)
        else:
            train_ori_imglist.append(img)
    return train_ori_imglist, val_ori_imglist


if __name__ == '__main__':
    a, b = split_data('../base')
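With the 606 facades labels described in Section 2, this split yields 485 training and 121 validation images; a quick check (a minimal sketch, assuming the .png labels sit under ../base):

    train_list, val_list = split_data('../base')
    print(len(train_list), len(val_list))  # 485 121 for the 606-image facades set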
mydatasets.py
from torch.utils.data.dataset import Dataset
import torchvision.transforms as transform
from PIL import Image
import cv2


class CreateDatasets(Dataset):
    def __init__(self, ori_imglist, img_size):
        self.ori_imglist = ori_imglist
        self.transform = transform.Compose([
            transform.ToTensor(),
            transform.Resize((img_size, img_size)),
            transform.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

    def __len__(self):
        return len(self.ori_imglist)

    def __getitem__(self, item):
        # the label image is a .png; cv2 loads BGR, so flip to RGB
        ori_img = cv2.imread(self.ori_imglist[item])
        ori_img = ori_img[:, :, ::-1]
        # the matching real photo shares the file name but is a .jpg
        real_img = Image.open(self.ori_imglist[item].replace('.png', '.jpg'))
        ori_img = self.transform(ori_img.copy())
        real_img = self.transform(real_img)
        return ori_img, real_img
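A quick way to inspect one sample (a minimal sketch, reusing split_data from above and assuming the paired .png/.jpg files under ../base):

    from split_data import split_data

    train_list, _ = split_data('../base')
    ds = CreateDatasets(train_list, img_size=256)
    label_img, photo = ds[0]
    print(label_img.shape, photo.shape)  # torch.Size([3, 256, 256]) for both
    print(label_img.min().item(), label_img.max().item())  # roughly -1.0 and 1.0 after Normalize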
(3) train
from torch.utils.tensorboard import SummaryWriter
from pix2Topix import pix2pixG_256, pix2pixD_256
import argparse
from mydatasets import CreateDatasets
from split_data import split_data
import os
from torch.utils.data.dataloader import DataLoader
import torch
import torch.optim as optim
import torch.nn as nn
from utils import train_one_epoch, val


def train(opt):
    batch = opt.batch
    data_path = opt.dataPath
    print_every = opt.every
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    epochs = opt.epoch
    img_size = opt.imgsize
    if not os.path.exists(opt.savePath):
        os.mkdir(opt.savePath)

    # load the datasets
    train_imglist, val_imglist = split_data(data_path)
    train_datasets = CreateDatasets(train_imglist, img_size)
    val_datasets = CreateDatasets(val_imglist, img_size)
    train_loader = DataLoader(dataset=train_datasets, batch_size=batch, shuffle=True, num_workers=opt.numworker,
                              drop_last=True)
    val_loader = DataLoader(dataset=val_datasets, batch_size=batch, shuffle=True, num_workers=opt.numworker,
                            drop_last=True)

    # build the networks
    pix_G = pix2pixG_256().to(device)
    pix_D = pix2pixD_256().to(device)

    # optimizers and losses (MSE for the GAN terms, L1 for reconstruction)
    optim_G = optim.Adam(pix_G.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optim_D = optim.Adam(pix_D.parameters(), lr=0.0002, betas=(0.5, 0.999))
    loss = nn.MSELoss()
    l1_loss = nn.L1Loss()
    start_epoch = 0

    # optionally resume from a checkpoint
    if opt.weight != '':
        ckpt = torch.load(opt.weight)
        pix_G.load_state_dict(ckpt['G_model'], strict=False)
        pix_D.load_state_dict(ckpt['D_model'], strict=False)
        start_epoch = ckpt['epoch'] + 1
    writer = SummaryWriter('train_logs')

    # training loop
    for epoch in range(start_epoch, epochs):
        loss_mG, loss_mD = train_one_epoch(G=pix_G, D=pix_D, train_loader=train_loader,
                                           optim_G=optim_G, optim_D=optim_D, writer=writer, loss=loss, device=device,
                                           plot_every=print_every, epoch=epoch, l1_loss=l1_loss)
        writer.add_scalars(main_tag='train_loss', tag_scalar_dict={
            'loss_G': loss_mG,
            'loss_D': loss_mD
        }, global_step=epoch)

        # validation
        val(G=pix_G, D=pix_D, val_loader=val_loader, loss=loss, l1_loss=l1_loss, device=device, epoch=epoch)

        # save a checkpoint
        torch.save({
            'G_model': pix_G.state_dict(),
            'D_model': pix_D.state_dict(),
            'epoch': epoch
        }, os.path.join(opt.savePath, 'pix2pix_256.pth'))


def cfg():
    parse = argparse.ArgumentParser()
    parse.add_argument('--batch', type=int, default=16)
    parse.add_argument('--epoch', type=int, default=200)
    parse.add_argument('--imgsize', type=int, default=256)
    parse.add_argument('--dataPath', type=str, default='../base', help='data root path')
    parse.add_argument('--weight', type=str, default='', help='load pre train weight')
    parse.add_argument('--savePath', type=str, default='./weights', help='weight save path')
    parse.add_argument('--numworker', type=int, default=4)
    parse.add_argument('--every', type=int, default=2, help='plot train result every * iters')
    opt = parse.parse_args()
    return opt


if __name__ == '__main__':
    opt = cfg()
    print(opt)
    train(opt)
utils.py
import torchvision
from tqdm import tqdm
import torch
import os


def train_one_epoch(G, D, train_loader, optim_G, optim_D, writer, loss, device, plot_every, epoch, l1_loss):
    pd = tqdm(train_loader)
    loss_D, loss_G = 0, 0
    step = 0
    G.train()
    D.train()
    for idx, data in enumerate(pd):
        in_img = data[0].to(device)
        real_img = data[1].to(device)

        # train D first: fake pairs should score 0, real pairs should score 1
        fake_img = G(in_img)
        D_fake_out = D(fake_img.detach(), in_img).squeeze()
        D_real_out = D(real_img, in_img).squeeze()
        ls_D1 = loss(D_fake_out, torch.zeros_like(D_fake_out))
        ls_D2 = loss(D_real_out, torch.ones_like(D_real_out))
        ls_D = (ls_D1 + ls_D2) * 0.5
        optim_D.zero_grad()
        ls_D.backward()
        optim_D.step()

        # then train G: fool D and stay close to the real photo (L1, weight 100)
        fake_img = G(in_img)
        D_fake_out = D(fake_img, in_img).squeeze()
        ls_G1 = loss(D_fake_out, torch.ones_like(D_fake_out))
        ls_G2 = l1_loss(fake_img, real_img)
        ls_G = ls_G1 + ls_G2 * 100
        optim_G.zero_grad()
        ls_G.backward()
        optim_G.step()

        loss_D += ls_D.item()
        loss_G += ls_G.item()
        pd.desc = 'train_{} G_loss: {} D_loss: {}'.format(epoch, ls_G.item(), ls_D.item())

        # log generated images during training
        if idx % plot_every == 0:
            writer.add_images(tag='train_epoch_{}'.format(epoch), img_tensor=0.5 * (fake_img + 1), global_step=step)
            step += 1
    mean_lsG = loss_G / len(train_loader)
    mean_lsD = loss_D / len(train_loader)
    return mean_lsG, mean_lsD


@torch.no_grad()
def val(G, D, val_loader, loss, device, l1_loss, epoch):
    pd = tqdm(val_loader)
    loss_D, loss_G = 0, 0
    G.eval()
    D.eval()
    all_loss = 10000
    for idx, item in enumerate(pd):
        in_img = item[0].to(device)
        real_img = item[1].to(device)
        fake_img = G(in_img)
        D_fake_out = D(fake_img, in_img).squeeze()
        # note: only the fake branch of the D loss is evaluated here
        ls_D1 = loss(D_fake_out, torch.zeros_like(D_fake_out))
        ls_D = ls_D1 * 0.5
        ls_G1 = loss(D_fake_out, torch.ones_like(D_fake_out))
        ls_G2 = l1_loss(fake_img, real_img)
        ls_G = ls_G1 + ls_G2 * 100
        loss_G += ls_G.item()
        loss_D += ls_D.item()
        pd.desc = 'val_{}: G_loss:{} D_Loss:{}'.format(epoch, ls_G.item(), ls_D.item())

        # keep the batch with the lowest combined loss
        all_ls = ls_G + ls_D
        if all_ls < all_loss:
            all_loss = all_ls
            best_image = fake_img
    result_img = (best_image + 1) * 0.5  # de-normalize from [-1,1] to [0,1]
    if not os.path.exists('./results'):
        os.mkdir('./results')
    torchvision.utils.save_image(result_img, './results/val_epoch{}.jpg'.format(epoch))
(4) test
from pix2Topix import pix2pixG_256
import torch
import torchvision.transforms as transform
import matplotlib.pyplot as plt
import cv2
from PIL import Image


def test(img_path):
    # .png label maps are read with cv2 (BGR -> RGB); anything else via PIL
    if img_path.endswith('.png'):
        img = cv2.imread(img_path)
        img = img[:, :, ::-1]
    else:
        img = Image.open(img_path)
    transforms = transform.Compose([
        transform.ToTensor(),
        transform.Resize((256, 256)),
        transform.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    img = transforms(img.copy())
    img = img[None].to('cuda')  # [1,3,256,256]
    # build the generator
    G = pix2pixG_256().to('cuda')
    # load trained weights
    ckpt = torch.load('weights/pix2pix_256.pth')
    G.load_state_dict(ckpt['G_model'], strict=False)
    G.eval()
    out = G(img)[0]
    out = out.permute(1, 2, 0)
    out = (0.5 * (out + 1)).cpu().detach().numpy()  # de-normalize to [0,1]
    plt.figure()
    plt.imshow(out)
    plt.show()


if __name__ == '__main__':
    test('../base/cmp_b0141.png')
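To write the result to disk instead of plotting it, the matplotlib lines in test() could be replaced with something like the following (a minimal variant; the output path is hypothetical, and save_image expects values in [0,1]):

    import torchvision

    out = G(img)[0].detach().cpu()
    torchvision.utils.save_image(0.5 * (out + 1), './test_result.jpg')  # hypothetical output path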
5. Results
(1) 128*128
Training the 128*128 network for 200 epochs gives the training losses shown below.
[Figure: training loss curves, 128*128]
The images below were generated by G after 200 epochs (on the validation set).
[Figure: validation samples after 200 epochs, 128*128]
(2) 256*256
Training the 256*256 network for 170 epochs gives the training losses shown below.
[Figure: training loss curves, 256*256]
Below are some images generated by G on the training set (the results on the training set are quite good).
[Figure: generated samples on the training set, 256*256]
Below are some images generated by G on the validation set (the results there are somewhat worse).
[Figure: generated samples on the validation set, 256*256]
6. Complete Code
Code: 代码 (extraction code: 1uee)
Weights: weights (extraction code: zmdo)
Dataset: data (extraction code: yydk)