torch.utils.data.DataLoader():构建可迭代的数据装载器, 训练的时候,每一个for循环,每一次iteration,就是从DataLoader中获取一个batch_size大小的数据的。
Dataloader()参数:
torch.utils.data.Dataset():Dataset抽象类, 所有自定义的Dataset都需要继承它,并且必须复写__getitem__()这个类方法。
__getitem__方法的是Dataset的核心,作用是接收一个索引, 返回一个样本, 看上面的函数,参数里面接收index,然后我们需要编写究竟如何根据这个索引去读取我们的数据部分。
torchvision已经预先实现了常用的Dataset, 其他预先实现的有: torchvision.datasets.CIFAR10,  可以读取CIFAR-10,以及ImageNet、COCO、MNIST、LSUN等数据集。
ImageFolder假设所有的文件按文件夹保存,每个文件夹下存储同一个类别的图片,文件夹名为类名,其构造函数如下:
ImageFolder(root, transform=None, target_transform=None, loader=default_loader)
参数:
示例:
文件夹格式:
train_path = r‘datasets/myDataSet/train‘
预处理格式:
train_transform = transforms.Compose([
    transforms.Resize((40,40)),
    transforms.RandomCrop(40,padding=4),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],
                         [0.229,0.224,0.225],)
])
dataset:
trainset = ImageFolder(train_path,transform = train_transform)
# print(trainset[30]) # 元组类型,第30号图片的(像素信息,label)
Data.DataLoader:
train_loader = Data.DataLoader(dataset=trainset, batch_size=4,shuffle=False)
for i,(img, target) in enumerate(train_loader):
    print(i)
    print(img.shape) # (batchsize, channel, H, W)
    print(target.shape) # (batch)
    print(target) # 一个batch图片对应的label
class myDataset(Data.Dataset):
    def __init__(self, path, transform):
        self.path = path
        self.transform = transform
        self.data_info = self.get_img_info(path)
        self.label = []
        for i in range(len(self.data_info)):
            self.label.append(list(self.data_info[i])[1])
    def __getitem__(self, idx):
        path_img = self.data_info[idx][0]
        label = self.label[idx]
        img = Image.open(path_img).convert(‘RGB‘)  # 0~255
        if self.transform is not None:
            img = self.transform(img)  # 在这里做transform,转为tensor等等
        return img, label, idx
    def __len__(self):
        return len(self.data_info)
    @staticmethod
    def get_img_info(data_dir):
        data_info = list()
        for root, dirs, _ in os.walk(data_dir):
            # 遍历类别
            for sub_dir in dirs:
                img_names = os.listdir(os.path.join(root, sub_dir))
                img_names = list(filter(lambda x: x.endswith(‘.jpg‘), img_names))
                # 遍历图片
                for i in range(len(img_names)):
                    img_name = img_names[i]
                    path_img = os.path.join(root, sub_dir, img_name)
                    label = int(sub_dir)
                    data_info.append((path_img, int(label)))
        return data_info
trainset = myDataset(train_path, train_transform)
train_loader = Data.DataLoader(dataset=trainset, batch_size=4,shuffle=True)
for i,(img, target, index) in enumerate(train_loader):
    print(i)
    print(img.shape) # (batchsize, channel, H, W)
    print(target.shape) # (batch)
    print(target) # 一个batch的图片对应的label
    print(index) #  一个batch的图片在数据集中对应的index
s
Pytorch数据读取机制(DataLoader)与图像预处理模块(transforms)
原文:https://www.cnblogs.com/bin888/p/15036953.html