本博客是作者复现《S3Net: A Single Stream Structure for Depth Guided Image Relighting》的训练数据集读取代码的笔记。
一、函数test_trainSet()
函数功能:测试类trainDataSetFromTrack2的功能。
1、给出输入原始图像的路径和引导图像路径
1
2
origin_img_path = '../datasets/alltrain/*.png'# 输入的原始图像的路径
guide_img_path = origin_img_path # 引导图像路径
2、根据图片路径和想获取的图片数量获取数据集。
1
dataset = trainDataSetFromTrack2(origin_img_path, guide_img_path,10)
3、用DataLoader获取可以输入神经网络中的数据集
1
trainloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0)
4、得到一组样本图片,iter函数将可序列化的对象序列化,next按顺序取序列化后对象的数据。
1
batchdict = next(iter(trainloader))
5、获取原始图像及其深度图、引导图像及其深度图。
1
ori_image, guide_image, ori_depth, guide_depth = batchdict['x']
6、将图片保存到对象路径中
1
save_img(ori_image,'./1.png')
函数代码:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
def test_trainSet():
# 创建数据集
origin_img_path = '../datasets/alltrain/*.png'# 输入的原始图像的路径
guide_img_path = origin_img_path # 引导图像路径
dataset = trainDataSetFromTrack2(origin_img_path, guide_img_path,10)# 根据图片路径读取数据集
trainloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0)
# 输出信息
print("训练集一共有{}/{}={}个的批次,其中{}是mini-batch".format(len(dataset), 1, len(trainloader), 1))
batchdict = next(iter(trainloader))# 得到一组样本数据
ori_image, guide_image, ori_depth, guide_depth = batchdict['x']
img_name = batchdict['img_name']
print(ori_image.shape)
print(guide_image.shape)
print(ori_depth.shape)
print(guide_depth.shape)
print('img_name', img_name)
save_img(ori_image,'./1.png')
二、类trainDataSetFromTrack2
类trainDataSetFromTrack2的功能:实现加载数据集所需的各个函数。
1、类头
该类继承自类Dataset,需要重载函数__init__()、getitem(self, index)、len(self)(这三个函数开头结尾都有两个下划线,typora文档里没显示出来)。
1
class trainDataSetFromTrack2(Dataset):
2、成员函数__init__()
函数代码:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
def __init__(self,
origin_img_path: str, # 输入文件所在的路径
guide_img_path: str, # 输出文件所在的路径
num:int,# 读取的图片数量
):
super(trainDataSetFromTrack2, self).__init__()
# 获取所有图片的路径
self.origin_img_paths, self.guide_img_paths = self._get_dataset_path(origin_img_path, guide_img_path)
self.len = len(self.origin_img_paths)
# 选取指定数量的图片
if num > 0 and num < self.len:
self.origin_img_paths = self.origin_img_paths[:num]
self.guide_img_paths =self.guide_img_paths[:num]
self.len = num
# 获取图像预处理函数
self.preprocess_fn = data_transform
print(f'含有{self.len} 个样本的数据集已被创建')
函数功能:
1、获取所有输入的原始图像和引导图像的路径
1
self.origin_img_paths, self.guide_img_paths = self._get_dataset_path(origin_img_path, guide_img_path)
2、获取读取整个数据集的大小
1
self.len = len(self.origin_img_paths)
3、获取指定数量的图片
1
2
3
4
if num > 0 and num < self.len:
self.origin_img_paths = self.origin_img_paths[:num]
self.guide_img_paths =self.guide_img_paths[:num]
self.len = num
4、获取图像预处理函数
1
self.preprocess_fn = data_transform
3、成员函数__getitem__()
函数代码:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# 获取一组图片数据
def __getitem__(self, index):
# 获取一组样本的路径
origin_img_path, guide_img_path = self.origin_img_paths[index % self.len], self.guide_img_paths[index % self.len]
origin_depth_name = origin_img_path.split('_')[0]+'.npy' # 拼接出原始图像对应深度图的路径:Image000+.npy
guide_depth_name = guide_img_path.split('_')[0]+'.npy' # 拼接出指导图像对应深度图的路径: Image001+.npy
truth_img_name = origin_img_path.split('_')[0]+'_'+guide_img_path.split('_')[1]+'_'+guide_img_path.split('_')[2]# 拼接出真实图像的路径:原始图像的前缀Image000+指导图像的后缀
# 读取该组样本的RGB图片
ori_image, guide_image,truth_img = map(self._read_rgb_img, (origin_img_path, guide_img_path,truth_img_name))
# 读取该组样本的depth图片
ori_depth, guide_depth = map(self._read_depth_img, (origin_depth_name, guide_depth_name))
# 获取该组样本对应的名称
img_name = origin_img_path.split('\\')[1]
return {'x':(ori_image, guide_image, ori_depth, guide_depth),
'y':truth_img,
'img_name':img_name}
函数功能:根据序号index,获取一组样本图片。
1、获取原始图像及其深度图、引导图像及其深度图、真实图像的路径
1
2
3
4
5
6
7
8
# 根据序号index,获取原始图像、引导图像的路径
origin_img_path, guide_img_path = self.origin_img_paths[index % self.len], self.guide_img_paths[index % self.len]
# 拼接出原始图像对应深度图的路径:Image000+.npy
origin_depth_name = origin_img_path.split('_')[0]+'.npy'
# 拼接出指导图像对应深度图的路径: Image001+.npy
guide_depth_name = guide_img_path.split('_')[0]+'.npy'
# 拼接出真实图像的路径:原始图像的前缀Image000+指导图像的后缀
truth_img_name = origin_img_path.split('_')[0]+'_'+guide_img_path.split('_')[1]+'_'+guide_img_path.split('_')[2]
2、# 读取该组样本的RGB图片
1
ori_image, guide_image,truth_img = map(self._read_rgb_img, (origin_img_path, guide_img_path,truth_img_name))
map()相当于调用了函数self._read_rgb_img三次,以上代码还可以写为
1
2
3
ori_image = self._read_rgb_img(origin_img_path)
guide_image = self._read_rgb_img(guide_img_path)
truth_img = self._read_rgb_img(truth_img_name)
3、读取该组样本的depth图片
1
ori_depth, guide_depth = map(self._read_depth_img, (origin_depth_name, guide_depth_name))
4、返回读取的这组样本图片
1
2
3
return {'x':(ori_image, guide_image, ori_depth, guide_depth),
'y':truth_img,
'img_name':img_name}
4、成员函数 __len__()
函数功能:返回读取图片的数量。
函数代码:
1
2
def __len__(self):
return self.len
5、成员函数_read_rgb_img()
类中的成员函数加上一个下划线_,这样类外就不能访问该函数。
函数功能:根据给定的图片路径,获取图片张量。
函数代码:
1
2
3
4
5
def _read_rgb_img(self,img_path):
img = Image.open(str(img_path)) # (1024,1024,4)
image_tensor = self.preprocess_fn(img) # tensor,size=(4,1024,1024)
image_tensor = image_tensor[:3, :, :] # tensor,size=(3,1024,1024)
return image_tensor
6、成员函数_read_depth_img()
函数功能:根据给定的图片路径,获取深度图片张量。
函数代码:
1
2
3
4
5
6
def _read_depth_img(self,depth_path):
depth = np.load(depth_path, allow_pickle=True).item()['normalized_depth']
ori_depth = torch.unsqueeze(torch.from_numpy(depth), 0) # 升维(1,1024,1024)
#ori_depth = torch.unsqueeze(ori_depth, 0) # 升维(1,1,1024,1024)
return ori_depth
7、成员函数_get_dataset_path()
函数功能:根据给定的图片文件夹的路径,获取图片文件夹中所有图片的路径。
glob.glob函数:搜索所有满足条件的项。
函数代码:
1
2
3
4
5
6
def _get_dataset_path(self, input_file_path, target_file_path):
origin_img_paths = sorted(glob.glob(input_file_path, recursive=True))
guide_img_paths = glob.glob(target_file_path, recursive=True)
random.shuffle(guide_img_paths)
#assert len(origin_img_paths) == len(guide_img_paths)
return origin_img_paths, guide_img_paths
三、数据增强手段
代码
1
2
3
data_transform = transforms.Compose([
transforms.ToTensor(),
])
四、函数save_img()
函数功能:把图片张量tensor_img保存到输出文件夹output_dir中。
函数代码:
1
2
3
def save_img(tensor_img,output_dir):
# 保存图像
torchvision.utils.save_image(tensor_img, output_dir)