使用PyTorch开始自己的首次视觉竞赛系列(1)--如何使用DataLoader

写这篇文章的初衷是希望在众多的开源baseline中,我想要形成我自己的一套pipeline。所以我将在近期的几个竞赛中开始尝试并逐渐整理出一套简洁易用的pipeline

请注意,本文编写于 1716 天前,最后修改于 1707 天前,其中某些信息可能已经过时。

使用PyTorch开始自己的首次视觉竞赛系列(1)--如何使用DataLoader

0 前言

写这篇文章的初衷是希望在众多的开源baseline中,我想要形成我自己的一套pipeline。所以我将在近期的几个竞赛中开始尝试并逐渐整理出一套简洁易用的pipeline。为后面的毕设工作也做好准备。我将为每个模块都进行部分知识的整理,并完成最后的系列博客

1 视觉模型的大体框架

  • 数据准备

    • DataLoader
    • Transform
  • 网络定义

    • 我们应该使用怎样的Backbone
    • 如何恰定义我们的优化器和损失
  • 训练及多折验证
  • TTA(Test Time Augmentation) 及预测最终结果

2 如何开始写DataLoader

2.1 使用 torch.utils.data.dataset.Dataset 来组织数据

将所有的数据集都表示为从 key 到 data samples 的映射,使用方法是继承该类,重写子类。所有的子类都必须重载 __len__() 函数以及 __getitem__() 函数,前者返回数据集的大小,后者将根据给定的整数 key 取得 data sample。

上面是PyTorch官方文档中所说到的,下面结合实例来讲述具体的使用。这里是我们的数据集格式,如下图所示

├─test           # 其中是测试集的图片,序号从0开始
└─train          # 其中是训练集的图片,序号从0开始
└─train.csv     # 训练集的标注,其格式是filename label两列数据
class MyDataset(data.Dataset):#需要继承data.Dataset
    def __init__(self):
        # TODO
        # 1. Initialize file path or list of file names.
        pass
    
    def __getitem__(self, index):
        # TODO
        # 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
        # 2. Preprocess the data (e.g. torchvision.Transform).
        # 3. Return a data pair (e.g. image and label).
        pass
    
    def __len__(self):
        # You should change 0 to the total size of your dataset.
        return 0
2.1.1 对 init() 函数进行重载

接下来我们需要对三个函数进行重载,首先对 init() 函数进行重载

  • 首先对于我们定义一个df用来存放train.csv的标注数据
  • 定义 transforms 用来对数据进行增强操作,在此处实现自己的增强策略。例如标准化和随机水平翻转等,
  • 定义 mode 变量用来区分训练集以及测试结合,默认是训练集
def __init__(self, df, transform, mode='train'):
    self.df = df
    self.transform = transform
    self.mode = mode
2.1.2 对 getitem() 函数进行重载

接着我们需要对 getitem() 函数进行重载

对训练集来说需要做下面几步,测试集则不用返回标签。除此之外,训练集的 df 是从 csv 文件中读取的,而测试集的 df 我们需要将测试集中的文件名组织为列表即可(可参照以下实例)

  • 用 Image 这个库将图片加载进来,并转换为 ‘RGB’ 格式
  • 进行数据增强
  • 返回的结果是增强后的图片以及图片的标签
def __getitem__(self, index):
    if self.mode == 'train':
        img = Image.open(self.df['filename'].iloc[index]).convert('RGB')
        img = self.transform(img)
        return img, torch.from_numpy(np.array(self.df['label'].iloc[index]))
    else:
        img = Image.open(self.df[index]).convert('RGB')
        img = self.transform(img)
        return img, torch.from_numpy(np.array(0))
    
# 测试集的组织方式实例
test_path_list = ['{}/{}.jpg'.format(config.image_test_path, x) for x in range(0, data_len)]
test_df = np.array(test_path_list)
2.1.3 对 len() 函数进行重载

len() 函数的重载很简单,返回 len(self.df) 即可

def __len__(self):
    return len(self.df)

2.2 使用 torchvision.datasets.ImageFolder 来组织数据

ImageFolder 组织数据比较局限,但我最开始使用的便是这种方式,其必须要求的数据格式如下:一个类的图像放在一个文件夹中。操作很不灵活,我刚开始用的时候花了一些时间将数据组织成要求的格式。现在还是重载DataLoader类会方便很多。

torchvision.datasets.ImageFolder(root, transform=None, target_transform=None, loader=<function default_loader>, is_valid_file=None)

2.3 使用 torch.utils.data.DataLoader 来加载数据

在数据组织完成之后,我们需要构建训练集、验证集和测试集的DataLoader。DataLoader 能够有效的帮你进行批量的数据迭代,下面是 DataLoader 的构造函数,下面简单解释一下常用的一些参数。

  • dataset:即上述的 DataSet 类或者 ImageFolder 类
  • batch_size:即进行批量训练的数据数量
  • shuffle:是否打乱,一般来说训练集为True,验证集和测试集为False
  • num_workers:即加载数据所用到的线程数量
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                 batch_sampler=None, num_workers=0, collate_fn=None,
                 pin_memory=False, drop_last=False, timeout=0,
                 worker_init_fn=None, multiprocessing_context=None):
    pass

一般来说,我们定义好 DataLoader 后需要去测试 DataLoader 是否写对了,我们就用一个函数来测试一下,并解释一下 DataLoader 的迭代过程。我们有两种方法迭代DataLoader,如下面的代码:

for step, (batch_x, batch_y) in enumerate(dataloader):
    pass
for batch_x, batch_y in dataloader:
    pass

接下来则是完整的测试过程:

# 验证dataloader是否正确的取到数据
def dataloader_test(dataloader):
    for epoch in range(3):
        for step, (batch_x, batch_y) in enumerate(dataloader):
            print("单个batch的size: ", batch_x.shape)
            print("单个batch的label张量: ", batch_y.numpy())
            
            # plt.imshow()只能接受3-D Tensor,所以也要用image[0]消去batch那一维
            img = batch_x[0]                            
            img = img.numpy()                    # FloatTensor转为ndarray  
            img = np.transpose(img, (1, 2, 0))  # 把channel那一维放到最后

            plt.imshow(img)
            plt.show()
            break
        break

2.4 后言

至此,我们已经能够完成对数据集的读取,写出 DataLoader 的代码并测试自己编写代码的正确性了。在总结这篇文章的过程中我也体会到,其实 PyTorch 的官方文档真的写的超详细了。在编写整个系列中,我都会尽可能地多参考官方文档。下面我将对数据集很重要的一部分:数据增强进行单独的讲述,分享我在这个过程中学到的知识。

3. 参考资料

  1. PyTorch官方文档
  2. StackOverflow论坛
  3. 称霸Kaggle的十大深度学习技巧

评论列表