首页 > 其他 > 详细

PyTorch笔记之 Dataset 和 Dataloader

时间:2019-11-27 21:06:06      阅读:548      评论:0      收藏:0      [点我收藏+]

一、简介

在 PyTorch 中,我们的数据集往往会用一个类去表示,在训练时用 Dataloader 产生一个 batch 的数据

https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py

比如官方例子中对 CIFAR10 图像数据集进行分类,就有用到这样的操作,具体代码如下所示

trainset = torchvision.datasets.CIFAR10(root=./data, train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root=./data, train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

简单说,用 一个类 抽象地表示数据集,而 Dataloader 作为生成器,每次产生一个 batch 大小的数据,节省内存

二、Dataset

Dataset 是 PyTorch 中用来表示数据集的一个抽象类,我们的数据集可以用这个类来表示,至少覆写下面两个方法即可

这返回数据前可以进行适当的数据处理,比如将原文用一串数字序列表示

  • __len__:数据集大小
  • __getitem__:实现这个方法后,可以通过下标的方式( dataset[i] )的来取得第 $i$ 个数据

下面我们来为编写一个类表示一个情感二分类数据集,继续用苏神整理的数据集

https://github.com/bojone/bert4keras/tree/master/examples/datasets

from torch.utils.data import Dataset, DataLoader
import pandas as pd

class SentimentDataset(Dataset):
    def __init__(self, path_to_file):
        self.dataset = pd.read_csv(path_to_file, sep="\t", names=["text", "label"])
    def __len__(self):
        return len(self.dataset)
    def __getitem__(self, idx):
        text = self.dataset.loc[idx, "text"]
        label = self.dataset.loc[idx, "label"]
        sample = {"text": text, "label": label}
        return sample

三、Dataloader

Dataloader 就是一个生成器,传入一个 Dataset 对象,它就会根据参数 batch_size 的值生成一个 batch 的数据

if __name__ == "__main__":
    sentiment_dataset = SentimentDataset("sentiment.test.data")
    sentiment_dataloader = DataLoader(sentiment_dataset, batch_size=4, shuffle=True, num_workers=2)
    for idx, batch_samples in enumerate(sentiment_dataloader):
        text_batchs, text_labels = batch_samples["text"], batch_samples["label"]
        print(text_batchs)

 

PyTorch笔记之 Dataset 和 Dataloader

原文:https://www.cnblogs.com/dogecheng/p/11930535.html

(0)
(0)
   
举报
评论 一句话评论(0
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!