3.3 数据读入
3.3 数据读入#
PyTorch数据读入是通过Dataset+DataLoader的方式完成的,Dataset定义好数据的格式和数据变换形式,DataLoader用iterative的方式不断读入批次数据。
经过本节的学习,你将收获:
PyTorch常见的数据读取方式
构建自己的数据读取流程
我们可以定义自己的Dataset类来实现灵活的数据读取,定义的类需要继承PyTorch自身的Dataset类。主要包含三个函数:
__init__
: 用于向类中传入外部参数,同时定义样本集__getitem__
: 用于逐个读取样本集合中的元素,可以进行一定的变换,并将返回训练/验证所需的数据__len__
: 用于返回数据集的样本数
下面以cifar10数据集为例给出构建Dataset类的方式:
import torch
from torchvision import datasets
train_data = datasets.ImageFolder(train_path, transform=data_transform)
val_data = datasets.ImageFolder(val_path, transform=data_transform)
这里使用了PyTorch自带的ImageFolder类的用于读取按一定结构存储的图片数据(path对应图片存放的目录,目录下包含若干子目录,每个子目录对应属于同一个类的图片)。
其中data_transform
可以对图像进行一定的变换,如翻转、裁剪等操作,可自己定义。这里我们会在下一章通过实战加以介绍并在notebook中做了示例代码。
这里我们给出一个自己定制Dataset的例子
import os
import pandas as pd
from torchvision.io import read_image
class MyDataset(Dataset):
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
"""
Args:
annotations_file (string): Path to the csv file with annotations.
img_dir (string): Directory with all the images.
transform (callable, optional): Optional transform to be applied
on a sample.
target_transform (callable, optional): Optional transform to be applied
on the target.
"""
self.img_labels = pd.read_csv(annotations_file)
self.img_dir = img_dir
self.transform = transform
self.target_transform = target_transform
def __len__(self):
return len(self.img_labels)
def __getitem__(self, idx):
"""
Args:
idx (int): Index
"""
img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
image = read_image(img_path)
label = self.img_labels.iloc[idx, 1]
if self.transform:
image = self.transform(image)
if self.target_transform:
label = self.target_transform(label)
return image, label
其中,我们的标签类似于以下的形式:
image1.jpg, 0
image2.jpg, 1
......
image9.jpg, 9
构建好Dataset后,就可以使用DataLoader来按批次读入数据了,实现代码如下:
from torch.utils.data import DataLoader
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, num_workers=4, shuffle=True, drop_last=True)
val_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size, num_workers=4, shuffle=False)
其中:
batch_size:样本是按“批”读入的,batch_size就是每次读入的样本数
num_workers:有多少个进程用于读取数据,Windows下该参数设置为0,Linux下常见的为4或者8,根据自己的电脑配置来设置
shuffle:是否将读入的数据打乱,一般在训练集中设置为True,验证集中设置为False
drop_last:对于样本最后一部分没有达到批次数的样本,使其不再参与训练
这里可以看一下我们的加载的数据。PyTorch中的DataLoader的读取可以使用next和iter来完成
import matplotlib.pyplot as plt
images, labels = next(iter(val_loader))
print(images.shape)
plt.imshow(images[0].transpose(1,2,0))
plt.show()