Cycle-GAN代码解读
1 model.py⽂件
1.1 初始化函数
# Reconstructed imports (characters were dropped during text extraction).
import torch
import torch.nn as nn
import torch.nn.functional as F


# Weight initialization function
def weights_init_normal(m):
    """Initialize a module's parameters for GAN training.

    Conv layers get weights from N(0, 0.02); BatchNorm2d layers get weights
    from N(1.0, 0.02) and zero bias.  Intended for ``net.apply(...)``.
    The body lines were lost in extraction; restored from the canonical
    PyTorch-GAN implementation.
    """
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
        # Some conv layers are created with bias=False, so guard the access.
        if hasattr(m, "bias") and m.bias is not None:
            torch.nn.init.constant_(m.bias.data, 0.0)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)
1.2 RESNET 模块定义
class ResidualBlock(nn.Module):
    """One generator residual block: two reflection-padded 3x3 convs with
    InstanceNorm, joined to the input by a skip connection.

    The channel count is preserved, so the output shape equals the input shape.
    (``lf`` in the extracted text is ``self`` with characters dropped.)
    """

    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()
        self.block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
        )

    def forward(self, x):
        # Residual connection: output = input + F(input).
        return x + self.block(x)
从⽣成器中截取⼀个resnet模块其结构如下所⽰。
1.3 模型定义
⽣成器定义:模型⼀上来就是3个“卷积块”,每个卷积块包含:⼀个2D卷积层,⼀个Instance Normalization层和⼀个ReLU。这3个“卷积块”是⽤来降采样的。然后是9个“残差块”,每个残差块包含2个卷积层,每个卷积层后⾯都有⼀个Instance Normalization 层,第⼀个Instance Normalization层后⾯是ReLU激活函数,这些使⽤残差连接。然后过3个“上采样块”,每个块包含⼀个2D转置卷积层,1个Instance Normalization和1个ReLU激活函数。最后⼀层是⼀个2D卷积层,使⽤tanh作为激活函数,该层⽣成的形状为
(256,256,3)的图像。这个Generator的输入和输出的大小是一模一样的,都是(256,256,3)。
class GeneratorResNet(nn.Module):
    """ResNet-style CycleGAN generator.

    Architecture: an initial 7x7 conv block, two stride-2 downsampling conv
    blocks, ``num_residual_blocks`` residual blocks, two upsampling blocks,
    and a final 7x7 conv with tanh.  Input and output have the same shape
    (channels, H, W).
    """

    def __init__(self, input_shape, num_residual_blocks):
        super(GeneratorResNet, self).__init__()
        channels = input_shape[0]

        # Initial convolution block
        out_features = 64
        model = [
            nn.ReflectionPad2d(channels),
            nn.Conv2d(channels, out_features, 7),
            nn.InstanceNorm2d(out_features),
            nn.ReLU(inplace=True),
        ]
        in_features = out_features

        # Downsampling: two stride-2 conv blocks, doubling the channels each time
        for _ in range(2):
            out_features *= 2
            model += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

        # Residual blocks (num_residual_blocks = 9 for 256x256 images)
        for _ in range(num_residual_blocks):
            model += [ResidualBlock(out_features)]

        # Upsampling: two nearest-neighbour upsample + conv blocks
        for _ in range(2):
            out_features //= 2
            model += [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(in_features, out_features, 3, stride=1, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

        # Output layer: 7x7 conv back to `channels`, squashed to [-1, 1] by tanh
        model += [nn.ReflectionPad2d(channels), nn.Conv2d(out_features, channels, 7), nn.Tanh()]

        # Missing in the extracted text: register the layer list as a Sequential.
        self.model = nn.Sequential(*model)

    def forward(self, x):
        # The extracted text had `del(x)` here; the generator must return model(x).
        return self.model(x)
判别器定义:判别⽹络的架构类似于PatchGAN中的判别⽹络架构,是⼀个包含⼏个卷积块的深度卷积神经⽹络。
class Discriminator(nn.Module):
    """PatchGAN discriminator.

    Four stride-2 conv blocks (each halving the spatial size) followed by a
    1-channel conv produce a (1, H/16, W/16) map of per-patch real/fake scores.
    """

    def __init__(self, input_shape):
        super(Discriminator, self).__init__()
        channels, height, width = input_shape

        # Calculate output shape of image discriminator (PatchGAN):
        # four stride-2 convs divide each spatial dimension by 2**4.
        self.output_shape = (1, height // 2 ** 4, width // 2 ** 4)

        def discriminator_block(in_filters, out_filters, normalize=True):
            """Returns downsampling layers of each discriminator block"""
            layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
            if normalize:
                layers.append(nn.InstanceNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        # Missing in the extracted text: the nn.Sequential wrapper around the blocks.
        self.model = nn.Sequential(
            *discriminator_block(channels, 64, normalize=False),
            *discriminator_block(64, 128),
            *discriminator_block(128, 256),
            *discriminator_block(256, 512),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(512, 1, 4, padding=1)
        )

    def forward(self, img):
        # The extracted text had `del(img)`; return the patch score map instead.
        return self.model(img)
判别器的结构如下所⽰:
2 datasets.py文件
主要是ImageDataset类的操作:__init__操作将trainA和trainB的路径读入files_A和files_B;__getitem__对两个文件夹的图片进行读取,若不是RGB图片则进行转换;__len__返回两个文件夹数据数量的最大值。
# Reconstructed imports for datasets.py.  `glob` and `random` were missing
# from the extracted text but are used by ImageDataset below.
import glob
import os
import random

from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms
# Convert a non-RGB (grayscale/palette) PIL image to RGB
def to_rgb(image):
    """Return *image* pasted onto a fresh RGB canvas of the same size.

    The extracted text had ``w(...)``; this is ``Image.new`` from PIL.
    """
    rgb_image = Image.new("RGB", image.size)
    rgb_image.paste(image)
    return rgb_image
# Unpaired image dataset reader for CycleGAN
class ImageDataset(Dataset):
    """Unpaired two-domain image dataset.

    ``__getitem__`` returns ``{"A": tensor, "B": tensor}`` with one transformed
    image from each domain; grayscale images are converted to RGB first.
    ``__len__`` is the size of the larger of the two folders.
    """

    def __init__(self, root, transforms_=None, unaligned=False, mode="train"):
        # Missing in the extracted text: build the torchvision transform pipeline.
        self.transform = transforms.Compose(transforms_)
        self.unaligned = unaligned
        self.files_A = sorted(glob.glob(os.path.join(root, "trainA") + "/*.*"))
        self.files_B = sorted(glob.glob(os.path.join(root, "trainB") + "/*.*"))
        # Alternative directory layout (kept from the original as a reference):
        # self.files_A = sorted(glob.glob(os.path.join(root, "%s/A" % mode) + "/*.*"))
        # self.files_B = sorted(glob.glob(os.path.join(root, "%s/B" % mode) + "/*.*"))

    def __getitem__(self, index):
        image_A = Image.open(self.files_A[index % len(self.files_A)])
        if self.unaligned:
            # Unaligned mode: pair A with a random B so pairs are not fixed.
            image_B = Image.open(self.files_B[random.randint(0, len(self.files_B) - 1)])
        else:
            image_B = Image.open(self.files_B[index % len(self.files_B)])

        # Convert grayscale images to rgb
        if image_A.mode != "RGB":
            image_A = to_rgb(image_A)
        if image_B.mode != "RGB":
            image_B = to_rgb(image_B)

        item_A = self.transform(image_A)
        item_B = self.transform(image_B)
        return {"A": item_A, "B": item_B}

    def __len__(self):
        # Larger folder size so every image of both domains is seen each epoch.
        return max(len(self.files_A), len(self.files_B))
3 utils.py⽂件
主要关注学习率衰减(LambdaLR)。
# Imports for utils.py, regrouped stdlib / third-party.  `random` was missing
# from the extracted text but is used by ReplayBuffer below.
import datetime
import random
import sys

import numpy as np
import torch
from torch.autograd import Variable
from torchvision.utils import save_image
class ReplayBuffer:
    """History buffer of previously generated images.

    Following Shrivastava et al. 2017, with probability 0.5 a newly pushed
    image is swapped for a random older one, which stabilises the
    discriminator by showing it a mix of current and past generator outputs.
    """

    def __init__(self, max_size=50):
        assert max_size > 0, "Empty buffer or trying to create a black hole. Be careful."
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        """Push a batch of images; return an equally sized batch where each
        image is either the pushed one or a random buffered one."""
        to_return = []
        for element in data.data:
            element = torch.unsqueeze(element, 0)
            if len(self.data) < self.max_size:
                # Buffer not full yet: store and return the new image as-is.
                self.data.append(element)
                to_return.append(element)
            else:
                if random.uniform(0, 1) > 0.5:
                    # Return a random old image and replace it with the new one.
                    i = random.randint(0, self.max_size - 1)
                    to_return.append(self.data[i].clone())
                    self.data[i] = element
                else:
                    to_return.append(element)
        return Variable(torch.cat(to_return))
class LambdaLR:
    """Linear learning-rate decay schedule for torch.optim.lr_scheduler.LambdaLR.

    Returns a multiplicative factor of 1.0 until ``decay_start_epoch``, then
    decays linearly to 0.0 at ``n_epochs``.  ``offset`` shifts the epoch
    counter (useful when resuming training from a checkpoint).
    """

    def __init__(self, n_epochs, offset, decay_start_epoch):
        assert (n_epochs - decay_start_epoch) > 0, "Decay must start before the training session ends!"
        self.n_epochs = n_epochs
        self.offset = offset
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch):
        # 1.0 before decay starts, then linearly down to 0.0 at the last epoch.
        return 1.0 - max(0, epoch + self.offset - self.decay_start_epoch) / (self.n_epochs - self.decay_start_epoch)
4 cyclegan.py⽂件
4.1 导⼊相关库以及进⾏参数设置
导⼊相关库