首页 > 其他 > 详细

Pytorch数据读取机制(DataLoader)与图像预处理模块(transforms)

时间:2021-07-20 23:21:24      阅读:11      评论:0      收藏:0      [点我收藏+]

Pytorch数据读取机制(DataLoader)与图像预处理模块(transforms)

1.DataLoader

torch.utils.data.DataLoader():构建可迭代的数据装载器, 训练的时候,每一个for循环,每一次iteration,就是从DataLoader中获取一个batch_size大小的数据的。

Dataloader()参数:

  • dataset: Dataset类,决定数据从哪读取(数据路径)以及如何读取(做哪些预处理)
  • batchsize: 批大小
  • num_works: 是否采用多进程读取机制
  • shuffle: 每一个epoch是否乱序
  • drop_last: 当样本数不能被batchsize整除时,是否舍弃最后一批数据。

2. Dataset

torch.utils.data.Dataset():Dataset抽象类, 所有自定义的Dataset都需要继承它,并且必须复写__getitem__()这个类方法。

__getitem__方法的是Dataset的核心,作用是接收一个索引, 返回一个样本, 看上面的函数,参数里面接收index,然后我们需要编写究竟如何根据这个索引去读取我们的数据部分。

2.1 ImageFolder

torchvision已经预先实现了常用的Dataset, 其他预先实现的有: torchvision.datasets.CIFAR10, 可以读取CIFAR-10,以及ImageNet、COCO、MNIST、LSUN等数据集。

ImageFolder假设所有的文件按文件夹保存,每个文件夹下存储同一个类别的图片,文件夹名为类名,其构造函数如下:

ImageFolder(root, transform=None, target_transform=None, loader=default_loader)

参数:

  • root: 图片路径
  • transform: 对PIL Image进行的转换操作,transform的输入是使用loader读取图片的返回对象
  • target_transform:对label的转换
  • loader:给定路径后如何读取图片,默认读取为RGB格式的PIL Image对象

示例:

文件夹格式:

技术分享图片

train_path = r‘datasets/myDataSet/train‘

预处理格式:

train_transform = transforms.Compose([
    transforms.Resize((40,40)),
    transforms.RandomCrop(40,padding=4),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],
                         [0.229,0.224,0.225],)
])

dataset:

trainset = ImageFolder(train_path,transform = train_transform)
# print(trainset[30]) # 元组类型,第30号图片的(像素信息,label)

Data.DataLoader:

train_loader = Data.DataLoader(dataset=trainset, batch_size=4,shuffle=False)
for i,(img, target) in enumerate(train_loader):
    print(i)
    print(img.shape) # (batchsize, channel, H, W)
    print(target.shape) # (batch)
    print(target) # 一个batch图片对应的label

2.2

class myDataset(Data.Dataset):
    def __init__(self, path, transform):
        self.path = path
        self.transform = transform
        self.data_info = self.get_img_info(path)
        self.label = []
        for i in range(len(self.data_info)):
            self.label.append(list(self.data_info[i])[1])

    def __getitem__(self, idx):
        path_img = self.data_info[idx][0]
        label = self.label[idx]
        img = Image.open(path_img).convert(‘RGB‘)  # 0~255
        if self.transform is not None:
            img = self.transform(img)  # 在这里做transform,转为tensor等等
        return img, label, idx

    def __len__(self):
        return len(self.data_info)

    @staticmethod
    def get_img_info(data_dir):
        data_info = list()
        for root, dirs, _ in os.walk(data_dir):
            # 遍历类别
            for sub_dir in dirs:
                img_names = os.listdir(os.path.join(root, sub_dir))
                img_names = list(filter(lambda x: x.endswith(‘.jpg‘), img_names))
                # 遍历图片
                for i in range(len(img_names)):
                    img_name = img_names[i]
                    path_img = os.path.join(root, sub_dir, img_name)
                    label = int(sub_dir)
                    data_info.append((path_img, int(label)))
        return data_info

trainset = myDataset(train_path, train_transform)

train_loader = Data.DataLoader(dataset=trainset, batch_size=4,shuffle=True)
for i,(img, target, index) in enumerate(train_loader):
    print(i)
    print(img.shape) # (batchsize, channel, H, W)
    print(target.shape) # (batch)
    print(target) # 一个batch的图片对应的label
    print(index) #  一个batch的图片在数据集中对应的index

s

Pytorch数据读取机制(DataLoader)与图像预处理模块(transforms)

原文:https://www.cnblogs.com/bin888/p/15036953.html

(0)
(0)
   
举报
评论 一句话评论(0
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!