DnCNN code study: data_generator.py
1. Source code with comments
# -*- coding: utf-8 -*-
# =============================================================================
# @article{zhang2017beyond,
# title={Beyond a {Gaussian} denoiser: Residual learning of deep {CNN} for image denoising},
# author={Zhang, Kai and Zuo, Wangmeng and Chen, Yunjin and Meng, Deyu and Zhang, Lei},
# journal={IEEE Transactions on Image Processing},
# year={2017},
# volume={26},
# number={7},
# pages={3142-3155},
# }
# by Kai Zhang (08/2018)
#
# https://github.com/cszn
# modified on the code from https://github.com/SaoYan/DnCNN-PyTorch
# =============================================================================
# no need to run this code separately
import matplotlib.pyplot as plt
import operator
import glob  # glob module for filename pattern matching
import cv2  # OpenCV, used here to read images
import numpy as np
# from multiprocessing import Pool
from torch.utils.data import Dataset  # torch.utils.data.Dataset is an abstract class representing a dataset
import torch  # the torch package provides multi-dimensional tensors and the math operations defined on them
from torch.utils.data import DataLoader
patch_size, stride = 4, 1  # patch size and stride (40 and 10 in the original code; reduced to 4 and 1 for this walkthrough)
aug_times = 1
scales = [1, 0.9, 0.8, 0.7]
batch_size = 128  # batch size
# dataset class that adds noise to the clean patches
class DenoisingDataset(Dataset):
    """Dataset wrapping tensors.
    Arguments:
        xs (Tensor): clean image patches
        sigma: noise level, e.g., 25
    """
    def __init__(self, xs):
        super(DenoisingDataset, self).__init__()
        self.xs = xs  # clean images
    def __getitem__(self, index):
        batch_x = self.xs[index]
        # torch.randn returns a tensor of random numbers drawn from the standard normal distribution, with shape given by the size argument
        # PyTorch math ops come in in-place and out-of-place forms; in-place ops such as mul_ overwrite the original memory with the result
        # original code: noise = torch.randn(batch_x.size()).mul_(self.sigma/255.0)
        noise = torch.randn(batch_x.size()).mul_(np.random.randint(55)/255.0)
        print('noise.shape', noise.shape)
        batch_y = batch_x + noise  # add noise
        return batch_y, batch_x  # return the noisy/clean pair
    def __len__(self):
        print(self.xs.size(0))
        return self.xs.size(0)  # xs.size(0) is the number of patches in the dataset
# display an image
def show(x, title=None, cbar=False, figsize=None):
    import matplotlib.pyplot as plt
    # figsize: figure width and height in inches; plt.figure asks the backend to create a new figure
    plt.figure(figsize=figsize)
    # interpolation: resampling method; cmap: colormap (grayscale here instead of the default RGB(A))
    plt.imshow(x, interpolation='nearest', cmap='gray')
    if title:
        plt.title(title)
    if cbar:
        plt.colorbar()
    plt.show()  # display the figure
def data_aug(img, mode=0):
    # data augmentation
    if mode == 0:    # return the original image
        return img
    elif mode == 1:  # flip: np.flipud flips the image up-down
        return np.flipud(img)
    elif mode == 2:  # rotate the matrix 90 degrees counter-clockwise
        return np.rot90(img)
    elif mode == 3:  # rotate 90 degrees, then flip up-down
        return np.flipud(np.rot90(img))
    elif mode == 4:  # np.rot90 with k rotates by 90*k degrees counter-clockwise (negative k rotates clockwise); here 180 degrees
        return np.rot90(img, k=2)
    elif mode == 5:  # rotate 180 degrees, then flip up-down
        return np.flipud(np.rot90(img, k=2))
    elif mode == 6:  # rotate 270 degrees counter-clockwise
        return np.rot90(img, k=3)
    elif mode == 7:  # rotate 270 degrees, then flip up-down
        return np.flipud(np.rot90(img, k=3))
# get multi-scale patches from a single image
def gen_patches(file_name):
    # get multiscale patches from a single image
    # cv2.imread returns a numpy.ndarray directly; by default the channel order is BGR (note: BGR, not RGB) with values in 0-255
    # flag = 0 reads an 8-bit, single-channel (grayscale) image; bit depth is the number of bits used to store each pixel
    img = cv2.imread(file_name, 0)  # gray scale
    plt.show()
    print('img:', img)
    h, w = img.shape
    print('h,w', h, w)
    patches = []
    for s in scales:
        h_scaled, w_scaled = int(h*s), int(w*s)
        print('h_scaled, w_scaled:', h_scaled, w_scaled)
        # cv2.resize takes the target size as (width, height); INTER_CUBIC is bicubic interpolation over a 4x4 pixel neighborhood
        img_scaled = cv2.resize(img, (w_scaled, h_scaled), interpolation=cv2.INTER_CUBIC)
        print(img_scaled.shape)
        # extract patches
        for i in range(0, h_scaled-patch_size+1, stride):
            for j in range(0, w_scaled-patch_size+1, stride):
                x = img_scaled[i:i+patch_size, j:j+patch_size]
                for k in range(0, aug_times):
                    # apply data augmentation; np.random.randint(0, 8) returns a random integer in [0, 8), i.e. 0 to 7
                    x_aug = data_aug(x, mode=np.random.randint(0, 8))
                    # print('x_aug', x_aug)
                    patches.append(x_aug)
                    print('len(patches)', len(patches))
    # return the patches
    print('patches:', patches)
    print('len(patches)', len(patches))
    return patches  # patches is a Python list
# generate clean patches from a dataset
def datagenerator(data_dir='testdata', verbose=True):
    # generate clean patches from a dataset
    file_list = glob.glob(data_dir+'/*.png')  # get the name list of all .png files
    # initialize
    data = []
    # generate patches
    for i in range(len(file_list)):
        # call the gen_patches helper defined above
        patches = gen_patches(file_list[i])
        print('gen_patches finished')
        count = 0
        for patch in patches:
            # print(patch)
            count = count + 1
            data.append(patch)
        print('count?', count)
        print('data.len', len(data))
        if verbose:
            print(str(i+1) + '/' + str(len(file_list)) + ' is done ^_^')
    # convert the data type to unsigned 8-bit integers (0 to 255)
    print(operator.eq(patches, data))
    # print('data:', data)
    data = np.array(data, dtype='uint8')
    # print(data)
    print('data.shape', data.shape)
    # np.expand_dims adds a new axis (the channel dimension)
    data = np.expand_dims(data, axis=3)
    print('data.shape', data.shape)
    print('len(data)', len(data))
    discard_n = len(data)-len(data)//batch_size*batch_size  # because of batch normalization
    print('discard_n', discard_n)
    # np.delete can remove whole rows or columns; here it drops the leftover patches along axis 0
    data = np.delete(data, range(discard_n), axis=0)
    print(data.shape, len(data))
    print('^_^-training data finished-^_^')
    return data
if __name__ == '__main__':
    data = datagenerator(data_dir='testdata')
    print(data.shape)
    # print(data)
    # print('Shape of result = ' + str(res.shape))
    # print('')
    # if not os.path.exists(save_dir):
    #     os.mkdir(save_dir)
    # np.save(save_dir+'clean_patches.npy', res)
    # print('Done.')
    data = data.astype('float32')/255.0  # scale the data into [0, 1]
    print('data.shape', data.shape)
    # torch.from_numpy converts a numpy.ndarray into a PyTorch Tensor; transpose reorders the array axes
    data = torch.from_numpy(data.transpose((0, 3, 1, 2)))  # tensor of the clean patches, N x C x H x W
    print('data.shape', data.shape)
    print(data)
    # dataset that adds noise to the clean patches
    DDataset = DenoisingDataset(data)
    # DLoader = DataLoader(dataset=DDataset, num_workers=4, drop_last=True, batch_size=batch_size, shuffle=True)
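For reference, here is a minimal usage sketch (my own addition, not part of the original script; it assumes the definitions above, including DDataset and batch_size, have already been executed) showing how the commented-out DataLoader line would be used to draw one batch of noisy/clean pairs:
# sketch only: wrap the dataset in a DataLoader and fetch a single batch
DLoader = DataLoader(dataset=DDataset, batch_size=batch_size, shuffle=True, drop_last=True)
for batch_y, batch_x in DLoader:
    # batch_y holds the noisy patches and batch_x the clean ones,
    # both shaped (batch_size, 1, patch_size, patch_size)
    print(batch_y.shape, batch_x.shape)
    break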
2. To better understand the data generation process
1. Change patch_size to 4*4 and the stride to 1. Using the paint tool, resize the image to h*w = 13*12 and keep only that one image in the testdata folder. The main goal is to test the DenoisingDataset class and the datagenerator function. To make the process easier to follow, many print() statements were added.
2. Modify the line below. cv2.resize takes the target size as width x height; if (h_scaled, w_scaled) is passed instead, the error "ValueError: setting an array element with a sequence" appears. The reason is that, when height and width differ, the wrong argument order produces patches smaller than 4*4, so the patch list can no longer be converted into an array. The error does not occur when height and width are equal (a short standalone sketch at the end of this section illustrates the argument order).
# cv2.resize takes the target size as (width, height); INTER_CUBIC is bicubic interpolation over a 4x4 pixel neighborhood
img_scaled = cv2.resize(img, (w_scaled, h_scaled), interpolation=cv2.INTER_CUBIC)
3. Test data: the original image is h*w = 500*500. Open the image with the paint tool, click Resize -> Pixels, untick "Maintain aspect ratio", change the size, and save.
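As an aside, here is a small standalone sketch of the cv2.resize argument-order pitfall described in point 2 (my own illustration; the 13x12 size matches the test image above):
import cv2
import numpy as np

img = np.zeros((13, 12), dtype=np.uint8)          # numpy shape is (height, width) = (13, 12)
h, w = img.shape
ok = cv2.resize(img, (int(w*0.7), int(h*0.7)))    # cv2.resize expects the target size as (width, height)
print(ok.shape)                                   # (9, 8): matches (h_scaled, w_scaled), so 4x4 patches always fit
bad = cv2.resize(img, (int(h*0.7), int(w*0.7)))   # swapped argument order
print(bad.shape)                                  # (8, 9): the array shape no longer matches h_scaled/w_scaled in the loops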
3. Module execution flow (partial)
In the if __name__ == '__main__' block, datagenerator() is called first; for every image in the testdata folder it calls gen_patches(). Since our folder contains only one image, we now step into gen_patches() and execute the following code.
def gen_patches(file_name):
    # get multiscale patches from a single image
    # cv2.imread returns a numpy.ndarray directly; by default the channel order is BGR (note: BGR, not RGB) with values in 0-255
    # flag = 0 reads an 8-bit, single-channel (grayscale) image; bit depth is the number of bits used to store each pixel
    img = cv2.imread(file_name, 0)  # gray scale
    plt.show()
    print('img:', img)
    h, w = img.shape
    print('h,w', h, w)
patches = [] stores the generated patches; patches is a Python list. The patch size is 4*4 and the stride is 1. The image is also rescaled to several scales; each scaled image matrix is traversed, and every extracted patch is passed through the data_aug() function (rotations and flips) for data augmentation. The resulting patches are appended to patches[].
    patches = []
    for s in scales:
        h_scaled, w_scaled = int(h*s), int(w*s)
        print('h_scaled, w_scaled:', h_scaled, w_scaled)
        # cv2.resize takes the target size as (width, height); INTER_CUBIC is bicubic interpolation over a 4x4 pixel neighborhood
        img_scaled = cv2.resize(img, (w_scaled, h_scaled), interpolation=cv2.INTER_CUBIC)
        print(img_scaled.shape)
        # extract patches
        for i in range(0, h_scaled-patch_size+1, stride):
            for j in range(0, w_scaled-patch_size+1, stride):
                x = img_scaled[i:i+patch_size, j:j+patch_size]
                for k in range(0, aug_times):
                    # apply data augmentation; np.random.randint(0, 8) returns a random integer in [0, 8), i.e. 0 to 7
                    x_aug = data_aug(x, mode=np.random.randint(0, 8))
                    # print('x_aug', x_aug)
                    patches.append(x_aug)
                    print('len(patches)', len(patches))
    # return the patches
    print('patches:', patches)
    print('len(patches)', len(patches))
    return patches  # patches is a Python list
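As a sanity check, here is a short sketch of my own that mirrors the loop bounds above and predicts how many patches one image yields with the modified settings from section 2 (patch_size=4, stride=1, a 13x12 test image):
h, w = 13, 12                        # size of the test image from section 2
patch_size, stride = 4, 1
total = 0
for s in [1, 0.9, 0.8, 0.7]:
    h_s, w_s = int(h*s), int(w*s)
    n_i = len(range(0, h_s - patch_size + 1, stride))   # patch positions along the height
    n_j = len(range(0, w_s - patch_size + 1, stride))   # patch positions along the width
    print(s, h_s, w_s, n_i * n_j)
    total += n_i * n_j
print('total patches:', total)       # 90 + 56 + 42 + 30 = 218
With these settings one image yields 218 patches; datagenerator then discards len(data) - len(data)//batch_size*batch_size = 218 - 128 = 90 of them, leaving exactly one full batch of 128.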