使用PyTorch开始自己的首次视觉竞赛系列(2)--DataLoader中的Transform

在深度学习中,数据十分重要。在我们构造的网络较为庞大的情况下,相当于我们需要从假设空间中以数据驱动的方式学出一种相应的参数组。而网络越宽越深越复杂,往往其参数越多,所以我们需要更多的数据去逼近一个可用解,在本节中,我大致介绍一下我们能够为数据做点什么

请注意,本文编写于 1504 天前,最后修改于 1183 天前,其中某些信息可能已经过时。

使用PyTorch开始自己的首次视觉竞赛系列(2)--DataLoader中的Transform

0 前言

在深度学习中,数据十分重要。在我们构造的网络较为庞大的情况下,相当于我们需要从假设空间中以数据驱动的方式学出一种相应的参数组。而网络越宽越深越复杂,往往其参数越多,所以我们需要更多的数据去逼近一个可用解,在本节中,我大致介绍一下我们能够为数据做点什么

1 使用 torchvision.transforms 提供的数据增强操作

1.1 transforms.Compose

They can be chained together using Compose

Compose函数主要是将各个数据增强操作进行组合,文档中是以chain这个词来描述的。我也去看了源码,其实现的很简单,就是将compose前的各个增强操作依次过一遍。但是在加入过多的增强操作后,读取数据的速度会有显著下降(已验证)。我这里推测使用多线程来加载数据是有一定帮助的。

1.2 较常用的增强操作

由于我习惯用 PIL 库进行读入,所以这里只挑出常用的用于 PIL 图像的数据增强操作。其中数据增强操作我大致分为几类:裁剪操作、翻转旋转操作、图像变换操作。值得注意的是:在使用时我们需要考虑增强后图像的标注信息会不会改变,如果改变我们需要谨慎操作

1.2.1 裁剪操作:torchvision.transforms.RandomCrop()

class torchvision.transforms.RandomCrop(size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant')

该增强操作是对图像进行随机裁剪。其参数中,使用较多的为 size 及 padding

  • size:裁剪后的尺寸
  • padding:补全的像素数量
1.2.2 翻转操作:torchvision.transforms.RandomHorizontalFlip()

class torchvision.transforms.`RandomHorizontalFlip`(p=0.5)

该增强操作是对图像进行随机的水平翻转,其默认的概率时0.5。

1.2.3 旋转操作: torchvision.transforms.RandomRotation()

class torchvision.transforms.RandomRotation(degrees, resample=False, expand=False, center=None, fill=0)

该增强操作主要是对图像进行随机旋转,我们主要使用参数 degrees 来调整角度。

1.2.4 图像变换之Resize: torchvision.transforms.Resize()

class torchvision.transforms.Resize(size, interpolation=2)

该操作主要是重置图像的分辨率,size和之前的定义一样,interpolation 参数是指插值方式。默认为PIL.Image.BILINEAR方法。

1.2.5 图像变换之标准化:torchvision.transforms.Normalize()

class torchvision.transforms.Normalize(mean, std)

该增强操作主要是对图像进行标准化,至于为什么要进行标准化在第三点中有所阐述。

1.2.6 图像变换之转为张量:torchvision.transforms.ToTensor()`

class torchvision.transforms.ToTensor()

该操作是将 PIL Image 或者 ndarray 转换为 tensor,并且归一化至[0-1]

1.2.7 图像变换之变换亮度、对比度和饱和度:torchvision.transforms.ColorJitter()

class torchvision.transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0)

该操作主要是对图像的亮度、对比度和饱和度做一定的变换。这对一些特殊场景的图片或许有用:例如工业缺陷检测中的反光过多的图像

1.3 实例说明

我们按照如下的操作来进行我们的数据增强,一般来说数据增强是在构建Dataset 或者 DataLoader的时候进行

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

train_transform = transforms.Compose([
    transforms.Resize([299, 299]),
    transforms.RandomRotation(15),
    transforms.RandomChoice([transforms.Resize([256, 256]), transforms.CenterCrop([256, 256])]),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),  # range [0, 255] -> [0.0,1.0]
    normalize,
])

val_transform = transforms.Compose([
    transforms.Resize([config.img_size, config.img_size]),
    transforms.ToTensor(),
    normalize,
])

2 构建我们自己的数据增强函数

刚刚讲到的都是在训练过程中做数据的在线增强,即在DataLoader提取数据时才对数据做操作。此外,我们也可以构建我们自己的数据增强函数做离线数据增强,通常离线增强是在数据集较少的时候进行。这一块内容我准备在之后阅读了 AutoAugment 论文之后再进行补充

3. 一些问题

1 为什么需要对数据集进行normalize?

通过对数据集进行normalize可以放大数据中更为相异的部分

2 为什么进行normalize的时候推荐使用mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]这样的参数?

在PyTorch官方文档(链接)中有相应描述,意思是说,这组参数是ImageNet上的mean和std,如果你使用ImageNet预训练模型的话,PyTorch官方推荐使用这种归一化的方式进行你自己数据集的训练过程的。

image-20200310174845108
image-20200310174845108

3 在什么情况下使用ImageNet的mean和std是不合适的?

因为ImageNet的数据集多是自然影像,如果使用如工业缺陷检测,文字识别这样的任务时。ImageNet的数据分布与当前数据集并不一致,所以建议是计算自己数据集的mean以及std。但自己从头训练和使用预训练模型的性能孰优孰劣,还有待商榷

4. 参考资料

  1. PyTorch官方文档
  2. PyTorch关于transforms的文档

评论列表