Implementing Image-to-Image Translation (pix2pix) in PyTorch

Contents

1. Introduction

2. Dataset

3. Network Architecture

4. Code

     (1) net

     (2) dataset

     (3) train

     (4) test

5. Results

    (1) 128*128

    (2) 256*256

6. Full Code


1. Introduction

        pix2pix can perform style transfer, image inpainting, semantic segmentation, image colorization, and more. This post focuses on translating facades labels into realistic photos. Since the original source code covers many features and is fairly complex, I extracted the main parts and made a few small modifications.

        In a traditional GAN, the generator G takes noise z as input, and the goal is for G to learn a distribution that fits the distribution of the real data. In pix2pix, the input to G is a facades label map instead; G must, on the one hand, generate images realistic enough to fool D and, on the other hand, keep its output close to the real photo paired with the label (this is where the L1 loss comes in). The noise in G is injected through the network's dropout layers (without this noise, the trained G would only reproduce the existing training samples instead of fitting the real data distribution).
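        For reference, the objective from the pix2pix paper combines a conditional GAN term with the L1 term (note that the code later in this post substitutes an MSE loss for the log-likelihood GAN terms, LSGAN-style, and uses $\lambda = 100$):

$$G^* = \arg\min_G \max_D \; \mathcal{L}_{cGAN}(G, D) + \lambda \, \mathcal{L}_{L1}(G)$$

$$\mathcal{L}_{cGAN}(G, D) = \mathbb{E}_{x,y}[\log D(x, y)] + \mathbb{E}_{x,z}[\log(1 - D(x, G(x, z)))], \qquad \mathcal{L}_{L1}(G) = \mathbb{E}_{x,y,z}\big[\lVert y - G(x, z) \rVert_1\big]$$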

        The discriminator D in pix2pix is a PatchGAN: instead of producing a single real/fake score for the whole image, D divides the image into multiple patches and scores each patch. D assigns high scores only when the image matches the facades labels and is close to the real sample.
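        A quick sanity check of the patch score map size (a minimal sketch; it mirrors the discriminator conv stack defined in Section 4(1): five 4x4 convolutions with strides 2, 2, 2, 1, 1 and padding 1):

def conv_out(n, k=4, s=2, p=1):
    # output spatial size of one conv layer
    return (n + 2 * p - k) // s + 1


size = 256
for stride in (2, 2, 2, 1, 1):
    size = conv_out(size, s=stride)
print(size)  # 30 -> D produces a 30x30 grid of patch scores for a 256x256 input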

2. Dataset

        The facades dataset is used. The training samples come in matched pairs (the facades label images are stored as png, the corresponding real photos as jpg); there are 606 images in total, split 8:2 into training and validation sets during training. Some example image pairs are shown below.
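        As a hypothetical sanity check (the directory '../base' and the .png/.jpg pairing convention are taken from the scripts below), one can verify that every label image has its matching photo:

import glob
import os

labels = glob.glob('../base/*.png')
missing = [p for p in labels if not os.path.exists(p.replace('.png', '.jpg'))]
print(len(labels), 'label images,', len(missing), 'missing photos')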

3. Network Architecture

        The generator G uses a U-Net, essentially an encoder-decoder with skip connections between mirrored layers; the structure is shown below (using a 256*256 input as the example). Dropout is used only in the up_2 and up_3 decoder blocks (the ones with 1024 input channels and 512 output channels).
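        The skip connections are implemented with torch.cat along the channel dimension, which is why each up block in the code below takes twice as many input channels as its mirrored down block produced. In miniature:

import torch

dec = torch.randn(1, 512, 2, 2)  # e.g. the up_1 output
enc = torch.randn(1, 512, 2, 2)  # e.g. the mirrored down_7 output
merged = torch.cat([dec, enc], dim=1)
print(merged.shape)  # torch.Size([1, 1024, 2, 2]) -> fed into up_2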

        The discriminator D is just a plain convolutional neural network; its structure is shown below.

4. Code

     (1) net

        Networks for both 128*128 and 256*256 inputs are provided here.

import torch.nn as nn
import torch
from collections import OrderedDict


# Downsampling block: LeakyReLU -> 4x4 stride-2 conv -> BatchNorm
class downsample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(downsample, self).__init__()
        self.down = nn.Sequential(
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x):
        return self.down(x)


# Upsampling block: ReLU -> 4x4 stride-2 transposed conv -> BatchNorm (-> optional Dropout)
class upsample(nn.Module):
    def __init__(self, in_channels, out_channels, drop_out=False):
        super(upsample, self).__init__()
        self.up = nn.Sequential(
            nn.ReLU(True),
            nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.Dropout(0.5) if drop_out else nn.Identity()
        )

    def forward(self, x):
        return self.up(x)


# ---------------------------------------------------------------------------------
# Generator G for 128*128 input
class pix2pixG_128(nn.Module):
    def __init__(self):
        super(pix2pixG_128, self).__init__()
        # down sample
        self.down_1 = nn.Conv2d(3, 64, 4, 2, 1)  # [batch,3,128,128]=>[batch,64,64,64]
        self.down_2 = downsample(64, 128)  # [batch,64,64,64]=>[batch,128,32,32]
        self.down_3 = downsample(128, 256)  # [batch,128,32,32]=>[batch,256,16,16]
        self.down_4 = downsample(256, 512)  # [batch,256,16,16]=>[batch,512,8,8]
        self.down_5 = downsample(512, 512)  # [batch,512,8,8]=>[batch,512,4,4]
        self.down_6 = downsample(512, 512)  # [batch,512,4,4]=>[batch,512,2,2]
        self.down_7 = downsample(512, 512)  # [batch,512,2,2]=>[batch,512,1,1]

        # up_sample
        self.up_1 = upsample(512, 512)  # [batch,512,1,1]=>[batch,512,2,2]
        self.up_2 = upsample(1024, 512, drop_out=True)  # [batch,1024,2,2]=>[batch,512,4,4]
        self.up_3 = upsample(1024, 512, drop_out=True)  # [batch,1024,4,4]=>[batch,512,8,8]
        self.up_4 = upsample(1024, 256)  # [batch,1024,8,8]=>[batch,256,16,16]
        self.up_5 = upsample(512, 128)  # [batch,512,16,16]=>[batch,128,32,32]
        self.up_6 = upsample(256, 64)  # [batch,256,32,32]=>[batch,64,64,64]

        self.last_Conv = nn.Sequential(
            nn.ConvTranspose2d(in_channels=128, out_channels=3, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

        self.init_weight()

    def init_weight(self):
        for w in self.modules():
            if isinstance(w, nn.Conv2d):
                nn.init.kaiming_normal_(w.weight, mode='fan_out')
                if w.bias is not None:
                    nn.init.zeros_(w.bias)
            elif isinstance(w, nn.ConvTranspose2d):
                nn.init.kaiming_normal_(w.weight, mode='fan_in')
            elif isinstance(w, nn.BatchNorm2d):
                nn.init.ones_(w.weight)
                nn.init.zeros_(w.bias)

    def forward(self, x):
        # down
        down_1 = self.down_1(x)
        down_2 = self.down_2(down_1)
        down_3 = self.down_3(down_2)
        down_4 = self.down_4(down_3)
        down_5 = self.down_5(down_4)
        down_6 = self.down_6(down_5)
        down_7 = self.down_7(down_6)
        # up
        up_1 = self.up_1(down_7)
        up_2 = self.up_2(torch.cat([up_1, down_6], dim=1))
        up_3 = self.up_3(torch.cat([up_2, down_5], dim=1))
        up_4 = self.up_4(torch.cat([up_3, down_4], dim=1))
        up_5 = self.up_5(torch.cat([up_4, down_3], dim=1))
        up_6 = self.up_6(torch.cat([up_5, down_2], dim=1))
        out = self.last_Conv(torch.cat([up_6, down_1], dim=1))
        return out


# Discriminator D for 128*128 input
class pix2pixD_128(nn.Module):
    def __init__(self):
        super(pix2pixD_128, self).__init__()

        # Basic conv -> BN -> LeakyReLU block
        def base_Conv_bn_lkrl(in_channels, out_channels, stride):
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 4, stride, 1),
                nn.BatchNorm2d(out_channels),
                nn.LeakyReLU(0.2)
            )

        D_dic = OrderedDict()
        in_channels = 6
        out_channels = 64
        for i in range(4):
            if i < 3:
                D_dic.update({'layer_{}'.format(i + 1): base_Conv_bn_lkrl(in_channels, out_channels, 2)})
            else:
                D_dic.update({'layer_{}'.format(i + 1): base_Conv_bn_lkrl(in_channels, out_channels, 1)})
            in_channels = out_channels
            out_channels *= 2
        D_dic.update({'last_layer': nn.Conv2d(512, 1, 4, 1, 1)})  # [batch,1,14,14]
        self.D_model = nn.Sequential(D_dic)

    def forward(self, x1, x2):
        in_x = torch.cat([x1, x2], dim=1)
        return self.D_model(in_x)


# ---------------------------------------------------------------------------------
# Generator G for 256*256 input
class pix2pixG_256(nn.Module):
    def __init__(self):
        super(pix2pixG_256, self).__init__()
        # down sample
        self.down_1 = nn.Conv2d(3, 64, 4, 2, 1)  # [batch,3,256,256]=>[batch,64,128,128]
        self.down_2 = downsample(64, 128)  # [batch,64,128,128]=>[batch,128,64,64]
        self.down_3 = downsample(128, 256)  # [batch,128,64,64]=>[batch,256,32,32]
        self.down_4 = downsample(256, 512)  # [batch,256,32,32]=>[batch,512,16,16]
        self.down_5 = downsample(512, 512)  # [batch,512,16,16]=>[batch,512,8,8]
        self.down_6 = downsample(512, 512)  # [batch,512,8,8]=>[batch,512,4,4]
        self.down_7 = downsample(512, 512)  # [batch,512,4,4]=>[batch,512,2,2]
        self.down_8 = downsample(512, 512)  # [batch,512,2,2]=>[batch,512,1,1]

        # up_sample
        self.up_1 = upsample(512, 512)  # [batch,512,1,1]=>[batch,512,2,2]
        self.up_2 = upsample(1024, 512, drop_out=True)  # [batch,1024,2,2]=>[batch,512,4,4]
        self.up_3 = upsample(1024, 512, drop_out=True)  # [batch,1024,4,4]=>[batch,512,8,8]
        self.up_4 = upsample(1024, 512)  # [batch,1024,8,8]=>[batch,512,16,16]
        self.up_5 = upsample(1024, 256)  # [batch,1024,16,16]=>[batch,256,32,32]
        self.up_6 = upsample(512, 128)  # [batch,512,32,32]=>[batch,128,64,64]
        self.up_7 = upsample(256, 64)  # [batch,256,64,64]=>[batch,64,128,128]

        self.last_Conv = nn.Sequential(
            nn.ConvTranspose2d(in_channels=128, out_channels=3, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

        self.init_weight()

    def init_weight(self):
        for w in self.modules():
            if isinstance(w, nn.Conv2d):
                nn.init.kaiming_normal_(w.weight, mode='fan_out')
                if w.bias is not None:
                    nn.init.zeros_(w.bias)
            elif isinstance(w, nn.ConvTranspose2d):
                nn.init.kaiming_normal_(w.weight, mode='fan_in')
            elif isinstance(w, nn.BatchNorm2d):
                nn.init.ones_(w.weight)
                nn.init.zeros_(w.bias)

    def forward(self, x):
        # down
        down_1 = self.down_1(x)
        down_2 = self.down_2(down_1)
        down_3 = self.down_3(down_2)
        down_4 = self.down_4(down_3)
        down_5 = self.down_5(down_4)
        down_6 = self.down_6(down_5)
        down_7 = self.down_7(down_6)
        down_8 = self.down_8(down_7)
        # up
        up_1 = self.up_1(down_8)
        up_2 = self.up_2(torch.cat([up_1, down_7], dim=1))
        up_3 = self.up_3(torch.cat([up_2, down_6], dim=1))
        up_4 = self.up_4(torch.cat([up_3, down_5], dim=1))
        up_5 = self.up_5(torch.cat([up_4, down_4], dim=1))
        up_6 = self.up_6(torch.cat([up_5, down_3], dim=1))
        up_7 = self.up_7(torch.cat([up_6, down_2], dim=1))
        out = self.last_Conv(torch.cat([up_7, down_1], dim=1))
        return out


# Discriminator D for 256*256 input
class pix2pixD_256(nn.Module):
    def __init__(self):
        super(pix2pixD_256, self).__init__()

        # Basic conv -> BN -> LeakyReLU block
        def base_Conv_bn_lkrl(in_channels, out_channels, stride):
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 4, stride, 1),
                nn.BatchNorm2d(out_channels),
                nn.LeakyReLU(0.2)
            )

        D_dic = OrderedDict()
        in_channels = 6
        out_channels = 64
        for i in range(4):
            if i < 3:
                D_dic.update({'layer_{}'.format(i + 1): base_Conv_bn_lkrl(in_channels, out_channels, 2)})
            else:
                D_dic.update({'layer_{}'.format(i + 1): base_Conv_bn_lkrl(in_channels, out_channels, 1)})
            in_channels = out_channels
            out_channels *= 2
        D_dic.update({'last_layer': nn.Conv2d(512, 1, 4, 1, 1)})  # [batch,1,30,30]
        self.D_model = nn.Sequential(D_dic)

    def forward(self, x1, x2):
        in_x = torch.cat([x1, x2], dim=1)
        return self.D_model(in_x)
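        A quick shape check for the networks above (a minimal sketch; batch size 2 is used because BatchNorm in training mode needs more than one value per channel at the 1*1 bottleneck):

if __name__ == '__main__':
    x = torch.randn(2, 3, 256, 256)
    G = pix2pixG_256()
    D = pix2pixD_256()
    fake = G(x)
    print(fake.shape)        # torch.Size([2, 3, 256, 256])
    print(D(fake, x).shape)  # torch.Size([2, 1, 30, 30])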

     (2) dataset

        split_data.py

import random
import glob


def split_data(dir_root):
    # Split the label images (*.png) into train/val lists at a fixed 8:2 ratio
    random.seed(0)
    ori_img = glob.glob(dir_root + '/*.png')
    k = 0.2  # validation fraction
    train_ori_imglist = []
    val_ori_imglist = []
    sample_data = random.sample(population=ori_img, k=int(k * len(ori_img)))
    for img in ori_img:
        if img in sample_data:
            val_ori_imglist.append(img)
        else:
            train_ori_imglist.append(img)
    return train_ori_imglist, val_ori_imglist


if __name__ == '__main__':
    train_list, val_list = split_data('../base')

         mydatasets.py

from torch.utils.data.dataset import Dataset
import torchvision.transforms as transform
from PIL import Image
import cv2


class CreateDatasets(Dataset):
    def __init__(self, ori_imglist, img_size):
        self.ori_imglist = ori_imglist
        self.transform = transform.Compose([
            transform.ToTensor(),
            transform.Resize((img_size, img_size)),
            transform.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

    def __len__(self):
        return len(self.ori_imglist)

    def __getitem__(self, item):
        # Label image: read with OpenCV (BGR) and convert to RGB
        ori_img = cv2.imread(self.ori_imglist[item])
        ori_img = ori_img[:, :, ::-1]
        # Paired real photo: same filename with a .jpg extension
        real_img = Image.open(self.ori_imglist[item].replace('.png', '.jpg'))
        ori_img = self.transform(ori_img.copy())
        real_img = self.transform(real_img)
        return ori_img, real_img
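        A minimal usage sketch (assuming the paired images live in '../base' and split_data from above is importable):

if __name__ == '__main__':
    from split_data import split_data

    train_list, _ = split_data('../base')
    ds = CreateDatasets(train_list, img_size=256)
    label, real = ds[0]
    print(label.shape, real.shape)  # torch.Size([3, 256, 256]) for both
    print(label.min().item(), label.max().item())  # within [-1, 1] after Normalize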

     (3) train

from torch.utils.tensorboard import SummaryWriter
from pix2Topix import pix2pixG_256, pix2pixD_256
import argparse
from mydatasets import CreateDatasets
from split_data import split_data
import os
from torch.utils.data.dataloader import DataLoader
import torch
import torch.optim as optim
import torch.nn as nn
from utils import train_one_epoch, val


def train(opt):
    batch = opt.batch
    data_path = opt.dataPath
    print_every = opt.every
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    epochs = opt.epoch
    img_size = opt.imgsize

    if not os.path.exists(opt.savePath):
        os.mkdir(opt.savePath)

    # Load and split the dataset
    train_imglist, val_imglist = split_data(data_path)
    train_datasets = CreateDatasets(train_imglist, img_size)
    val_datasets = CreateDatasets(val_imglist, img_size)

    train_loader = DataLoader(dataset=train_datasets, batch_size=batch, shuffle=True, num_workers=opt.numworker,
                              drop_last=True)
    val_loader = DataLoader(dataset=val_datasets, batch_size=batch, shuffle=True, num_workers=opt.numworker,
                            drop_last=True)

    # Instantiate the networks
    pix_G = pix2pixG_256().to(device)
    pix_D = pix2pixD_256().to(device)

    # Optimizers and loss functions
    optim_G = optim.Adam(pix_G.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optim_D = optim.Adam(pix_D.parameters(), lr=0.0002, betas=(0.5, 0.999))
    loss = nn.MSELoss()  # LSGAN-style adversarial loss
    l1_loss = nn.L1Loss()
    start_epoch = 0

    # Resume from a checkpoint if one is given
    if opt.weight != '':
        ckpt = torch.load(opt.weight)
        pix_G.load_state_dict(ckpt['G_model'], strict=False)
        pix_D.load_state_dict(ckpt['D_model'], strict=False)
        start_epoch = ckpt['epoch'] + 1

    writer = SummaryWriter('train_logs')
    # Training loop
    for epoch in range(start_epoch, epochs):
        loss_mG, loss_mD = train_one_epoch(G=pix_G, D=pix_D, train_loader=train_loader,
                                           optim_G=optim_G, optim_D=optim_D, writer=writer, loss=loss, device=device,
                                           plot_every=print_every, epoch=epoch, l1_loss=l1_loss)

        writer.add_scalars(main_tag='train_loss', tag_scalar_dict={
            'loss_G': loss_mG,
            'loss_D': loss_mD
        }, global_step=epoch)

        # Validation
        val(G=pix_G, D=pix_D, val_loader=val_loader, loss=loss, l1_loss=l1_loss, device=device, epoch=epoch)
        # Save a checkpoint (to the directory given by --savePath)
        torch.save({
            'G_model': pix_G.state_dict(),
            'D_model': pix_D.state_dict(),
            'epoch': epoch
        }, os.path.join(opt.savePath, 'pix2pix_256.pth'))


def cfg():
    parse = argparse.ArgumentParser()
    parse.add_argument('--batch', type=int, default=16)
    parse.add_argument('--epoch', type=int, default=200)
    parse.add_argument('--imgsize', type=int, default=256)
    parse.add_argument('--dataPath', type=str, default='../base', help='data root path')
    parse.add_argument('--weight', type=str, default='', help='load pre train weight')
    parse.add_argument('--savePath', type=str, default='./weights', help='weight save path')
    parse.add_argument('--numworker', type=int, default=4)
    parse.add_argument('--every', type=int, default=2, help='plot train result every * iters')
    opt = parse.parse_args()
    return opt


if __name__ == '__main__':
    opt = cfg()
    print(opt)
    train(opt)
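        With the defaults in cfg(), training can be launched with, for example:

python train.py --dataPath ../base --batch 16 --imgsize 256 --numworker 4

        and resumed from a saved checkpoint by passing --weight ./weights/pix2pix_256.pth.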

       utils.py 

import torchvision
from tqdm import tqdm
import torch
import os


def train_one_epoch(G, D, train_loader, optim_G, optim_D, writer, loss, device, plot_every, epoch, l1_loss):
    pd = tqdm(train_loader)
    loss_D, loss_G = 0, 0
    step = 0
    G.train()
    D.train()
    for idx, data in enumerate(pd):
        in_img = data[0].to(device)
        real_img = data[1].to(device)
        # Train D first
        fake_img = G(in_img)
        D_fake_out = D(fake_img.detach(), in_img).squeeze()
        D_real_out = D(real_img, in_img).squeeze()
        ls_D1 = loss(D_fake_out, torch.zeros_like(D_fake_out))
        ls_D2 = loss(D_real_out, torch.ones_like(D_real_out))
        ls_D = (ls_D1 + ls_D2) * 0.5

        optim_D.zero_grad()
        ls_D.backward()
        optim_D.step()

        # Then train G
        fake_img = G(in_img)
        D_fake_out = D(fake_img, in_img).squeeze()
        ls_G1 = loss(D_fake_out, torch.ones_like(D_fake_out))
        ls_G2 = l1_loss(fake_img, real_img)
        ls_G = ls_G1 + ls_G2 * 100  # lambda = 100 for the L1 term

        optim_G.zero_grad()
        ls_G.backward()
        optim_G.step()

        loss_D += ls_D.item()
        loss_G += ls_G.item()

        pd.desc = 'train_{} G_loss: {} D_loss: {}'.format(epoch, ls_G.item(), ls_D.item())
        # Log generated images to TensorBoard
        if idx % plot_every == 0:
            writer.add_images(tag='train_epoch_{}'.format(epoch), img_tensor=0.5 * (fake_img + 1), global_step=step)
            step += 1
    mean_lsG = loss_G / len(train_loader)
    mean_lsD = loss_D / len(train_loader)
    return mean_lsG, mean_lsD


@torch.no_grad()
def val(G, D, val_loader, loss, device, l1_loss, epoch):
    pd = tqdm(val_loader)
    loss_D, loss_G = 0, 0
    G.eval()
    D.eval()
    all_loss = float('inf')  # track the lowest combined loss to keep the best batch
    for idx, item in enumerate(pd):
        in_img = item[0].to(device)
        real_img = item[1].to(device)
        fake_img = G(in_img)
        D_fake_out = D(fake_img, in_img).squeeze()
        ls_D1 = loss(D_fake_out, torch.zeros_like(D_fake_out))
        ls_D = ls_D1 * 0.5  # note: only the fake-image term is computed here
        ls_G1 = loss(D_fake_out, torch.ones_like(D_fake_out))
        ls_G2 = l1_loss(fake_img, real_img)
        ls_G = ls_G1 + ls_G2 * 100
        loss_G += ls_G
        loss_D += ls_D
        pd.desc = 'val_{}: G_loss:{} D_Loss:{}'.format(epoch, ls_G.item(), ls_D.item())

        # Keep the generated batch with the lowest combined loss
        all_ls = ls_G + ls_D
        if all_ls < all_loss:
            all_loss = all_ls
            best_image = fake_img
    result_img = (best_image + 1) * 0.5  # un-normalize to [0, 1]
    if not os.path.exists('./results'):
        os.mkdir('./results')

    torchvision.utils.save_image(result_img, './results/val_epoch{}.jpg'.format(epoch))

     (4) test

from pix2Topix import pix2pixG_256
import torch
import torchvision.transforms as transform
import matplotlib.pyplot as plt
import cv2
from PIL import Image

def test(img_path):
    if img_path.endswith('.png'):
        img = cv2.imread(img_path)
        img = img[:, :, ::-1]
    else:
        img = Image.open(img_path)

    transforms = transform.Compose([
        transform.ToTensor(),
        transform.Resize((256, 256)),
        transform.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    img = transforms(img.copy())
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    img = img[None].to(device)  # [1,3,256,256]

    # Instantiate the network
    G = pix2pixG_256().to(device)
    # Load the trained weights
    ckpt = torch.load('weights/pix2pix_256.pth', map_location=device)
    G.load_state_dict(ckpt['G_model'], strict=False)

    G.eval()
    out = G(img)[0]
    out = out.permute(1, 2, 0)  # CHW -> HWC for matplotlib
    out = (0.5 * (out + 1)).cpu().detach().numpy()  # un-normalize to [0, 1]
    plt.figure()
    plt.imshow(out)
    plt.show()


if __name__ == '__main__':
    test('../base/cmp_b0141.png')

5. Results

    (1) 128*128

        The 128*128 network was trained for 200 epochs; the training losses are shown below.

        The images below were generated by G after 200 epochs (on the validation set).

    (2) 256*256

        The 256*256 network was trained for 170 epochs; the training losses are shown below.

        Below are some images generated by G on the training set (the results on the training set are quite good).

        Below are some images generated by G on the validation set (the validation results are somewhat worse).

6. Full Code

        Code: code (extraction code: 1uee)

        Weights: weights (extraction code: zmdo)

        Dataset: data (extraction code: yydk)