Posts 「项目复现」S3net的训练数据集读取
Post
Cancel

「项目复现」S3net的训练数据集读取

​ 本博客是作者复现《S3Net: A Single Stream Structure for Depth Guided Image Relighting》的训练数据集读取代码的笔记。

一、函数test_trainSet()

函数功能:测试类trainDataSetFromTrack2的功能。

1、给出输入原始图像的路径和引导图像路径

1
2
origin_img_path = '../datasets/alltrain/*.png'# 输入的原始图像的路径
guide_img_path = origin_img_path # 引导图像路径

2、根据图片路径和想获取的图片数量获取数据集。

1
dataset = trainDataSetFromTrack2(origin_img_path, guide_img_path,10)

3、用DataLoader获取可以输入神经网络中的数据集

1
trainloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0)

4、得到一组样本图片,iter函数将可序列化的对象序列化,next按顺序取序列化后对象的数据。

1
batchdict = next(iter(trainloader))

5、获取原始图像及其深度图、引导图像及其深度图。

1
ori_image, guide_image, ori_depth, guide_depth = batchdict['x']

6、将图片保存到对象路径中

1
save_img(ori_image,'./1.png')

函数代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
def test_trainSet():
    # 创建数据集
    origin_img_path = '../datasets/alltrain/*.png'# 输入的原始图像的路径
    guide_img_path = origin_img_path # 引导图像路径
    dataset = trainDataSetFromTrack2(origin_img_path, guide_img_path,10)# 根据图片路径读取数据集
    trainloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0)
    # 输出信息
    print("训练集一共有{}/{}={}个的批次,其中{}是mini-batch".format(len(dataset), 1, len(trainloader), 1))
    batchdict = next(iter(trainloader))# 得到一组样本数据
    ori_image, guide_image, ori_depth, guide_depth = batchdict['x']
    img_name = batchdict['img_name']
    print(ori_image.shape)
    print(guide_image.shape)
    print(ori_depth.shape)
    print(guide_depth.shape)
    print('img_name', img_name)
    save_img(ori_image,'./1.png')

二、类trainDataSetFromTrack2

类trainDataSetFromTrack2的功能:实现加载数据集所需的各个函数。

1、类头

该类继承自类Dataset,需要重载函数__init__()、getitem(self, index)、len(self)(这三个函数开头结尾都有两个下划线,typora文档里没显示出来)。

1
class trainDataSetFromTrack2(Dataset):

2、成员函数__init__()

函数代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
def __init__(self,
                 origin_img_path: str,  # 输入文件所在的路径
                 guide_img_path: str,  # 输出文件所在的路径
                 num:int,# 读取的图片数量
                 ):
    super(trainDataSetFromTrack2, self).__init__()
    # 获取所有图片的路径
    self.origin_img_paths, self.guide_img_paths = self._get_dataset_path(origin_img_path, guide_img_path)
    self.len = len(self.origin_img_paths)
    # 选取指定数量的图片
    if num > 0 and num < self.len:
        self.origin_img_paths = self.origin_img_paths[:num]
        self.guide_img_paths =self.guide_img_paths[:num]
        self.len = num
    # 获取图像预处理函数
    self.preprocess_fn = data_transform
    print(f'含有{self.len} 个样本的数据集已被创建')

函数功能

1、获取所有输入的原始图像和引导图像的路径

1
self.origin_img_paths, self.guide_img_paths = self._get_dataset_path(origin_img_path, guide_img_path)

2、获取读取整个数据集的大小

1
self.len = len(self.origin_img_paths)

3、获取指定数量的图片

1
2
3
4
if num > 0 and num < self.len:
    self.origin_img_paths = self.origin_img_paths[:num]
    self.guide_img_paths =self.guide_img_paths[:num]
    self.len = num

4、获取图像预处理函数

1
self.preprocess_fn = data_transform

3、成员函数__getitem__()

函数代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# 获取一组图片数据
    def __getitem__(self, index):
        # 获取一组样本的路径
        origin_img_path, guide_img_path = self.origin_img_paths[index % self.len], self.guide_img_paths[index % self.len]
        origin_depth_name = origin_img_path.split('_')[0]+'.npy'  # 拼接出原始图像对应深度图的路径:Image000+.npy
        guide_depth_name = guide_img_path.split('_')[0]+'.npy' # 拼接出指导图像对应深度图的路径: Image001+.npy
        truth_img_name = origin_img_path.split('_')[0]+'_'+guide_img_path.split('_')[1]+'_'+guide_img_path.split('_')[2]# 拼接出真实图像的路径:原始图像的前缀Image000+指导图像的后缀
        # 读取该组样本的RGB图片
        ori_image, guide_image,truth_img = map(self._read_rgb_img, (origin_img_path, guide_img_path,truth_img_name))

        # 读取该组样本的depth图片
        ori_depth, guide_depth = map(self._read_depth_img, (origin_depth_name, guide_depth_name))
        # 获取该组样本对应的名称
        img_name = origin_img_path.split('\\')[1]
        return {'x':(ori_image, guide_image, ori_depth, guide_depth),
                'y':truth_img,
                'img_name':img_name}

函数功能:根据序号index,获取一组样本图片。

1、获取原始图像及其深度图、引导图像及其深度图、真实图像的路径

1
2
3
4
5
6
7
8
# 根据序号index,获取原始图像、引导图像的路径
origin_img_path, guide_img_path = self.origin_img_paths[index % self.len], self.guide_img_paths[index % self.len]
# 拼接出原始图像对应深度图的路径:Image000+.npy
origin_depth_name = origin_img_path.split('_')[0]+'.npy' 
# 拼接出指导图像对应深度图的路径: Image001+.npy
guide_depth_name = guide_img_path.split('_')[0]+'.npy'
# 拼接出真实图像的路径:原始图像的前缀Image000+指导图像的后缀
truth_img_name = origin_img_path.split('_')[0]+'_'+guide_img_path.split('_')[1]+'_'+guide_img_path.split('_')[2]

2、# 读取该组样本的RGB图片

1
ori_image, guide_image,truth_img = map(self._read_rgb_img, (origin_img_path, guide_img_path,truth_img_name))

map()相当于调用了函数self._read_rgb_img三次,以上代码还可以写为

1
2
3
ori_image = self._read_rgb_img(origin_img_path)
guide_image = self._read_rgb_img(guide_img_path)
truth_img = self._read_rgb_img(truth_img_name)

3、读取该组样本的depth图片

1
ori_depth, guide_depth = map(self._read_depth_img, (origin_depth_name, guide_depth_name))

4、返回读取的这组样本图片

1
2
3
 return {'x':(ori_image, guide_image, ori_depth, guide_depth),
         'y':truth_img,
         'img_name':img_name}

4、成员函数 __len__()

函数功能:返回读取图片的数量。

函数代码:

1
2
  def __len__(self):
      return self.len

5、成员函数_read_rgb_img()

类中的成员函数加上一个下划线_,这样类外就不能访问该函数。

函数功能:根据给定的图片路径,获取图片张量。

函数代码:

1
2
3
4
5
def _read_rgb_img(self,img_path):
    img = Image.open(str(img_path))  # (1024,1024,4)
    image_tensor = self.preprocess_fn(img)  # tensor,size=(4,1024,1024)
    image_tensor = image_tensor[:3, :, :]  # tensor,size=(3,1024,1024)
    return image_tensor

6、成员函数_read_depth_img()

函数功能:根据给定的图片路径,获取深度图片张量。

函数代码:

1
2
3
4
5
6
def _read_depth_img(self,depth_path):
    depth = np.load(depth_path, allow_pickle=True).item()['normalized_depth']
    ori_depth = torch.unsqueeze(torch.from_numpy(depth), 0)  # 升维(1,1024,1024)
    #ori_depth = torch.unsqueeze(ori_depth, 0)  # 升维(1,1,1024,1024)
    return ori_depth

7、成员函数_get_dataset_path()

函数功能:根据给定的图片文件夹的路径,获取图片文件夹中所有图片的路径。

glob.glob函数:搜索所有满足条件的项。

函数代码:

1
2
3
4
5
6
def _get_dataset_path(self, input_file_path, target_file_path):
    origin_img_paths = sorted(glob.glob(input_file_path, recursive=True))
    guide_img_paths = glob.glob(target_file_path, recursive=True)
    random.shuffle(guide_img_paths)
    #assert len(origin_img_paths) == len(guide_img_paths)
    return origin_img_paths, guide_img_paths

三、数据增强手段

代码

1
2
3
data_transform = transforms.Compose([
    transforms.ToTensor(),
])

四、函数save_img()

函数功能:把图片张量tensor_img保存到输出文件夹output_dir中。

函数代码:

1
2
3
def save_img(tensor_img,output_dir):
    # 保存图像
    torchvision.utils.save_image(tensor_img, output_dir)
This post is licensed under CC BY 4.0 by the author.

「项目复现」S3net的训练代码实现

「项目复现」S3net的网络结构实现