在 PyTorch 中,我们的数据集往往会用一个类去表示,在训练时用 Dataloader 产生一个 batch 的数据
比如官方例子中对 CIFAR10 图像数据集进行分类,就有用到这样的操作,具体代码如下所示
trainset = torchvision.datasets.CIFAR10(root=‘./data‘, train=True, download=True, transform=transform) trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2) testset = torchvision.datasets.CIFAR10(root=‘./data‘, train=False, download=True, transform=transform) testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)
简单说,用 一个类 抽象地表示数据集,而 Dataloader 作为生成器,每次产生一个 batch 大小的数据,节省内存
Dataset 是 PyTorch 中用来表示数据集的一个抽象类,我们的数据集可以用这个类来表示,至少覆写下面两个方法即可
这返回数据前可以进行适当的数据处理,比如将原文用一串数字序列表示
下面我们来为编写一个类表示一个情感二分类数据集,继续用苏神整理的数据集
https://github.com/bojone/bert4keras/tree/master/examples/datasets
from torch.utils.data import Dataset, DataLoader import pandas as pd class SentimentDataset(Dataset): def __init__(self, path_to_file): self.dataset = pd.read_csv(path_to_file, sep="\t", names=["text", "label"]) def __len__(self): return len(self.dataset) def __getitem__(self, idx): text = self.dataset.loc[idx, "text"] label = self.dataset.loc[idx, "label"] sample = {"text": text, "label": label} return sample
Dataloader 就是一个生成器,传入一个 Dataset 对象,它就会根据参数 batch_size 的值生成一个 batch 的数据
if __name__ == "__main__": sentiment_dataset = SentimentDataset("sentiment.test.data") sentiment_dataloader = DataLoader(sentiment_dataset, batch_size=4, shuffle=True, num_workers=2) for idx, batch_samples in enumerate(sentiment_dataloader): text_batchs, text_labels = batch_samples["text"], batch_samples["label"] print(text_batchs)
PyTorch笔记之 Dataset 和 Dataloader
原文:https://www.cnblogs.com/dogecheng/p/11930535.html