import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets  # torchvision.datasets provides ready-made vision datasets such as FashionMNIST
from torchvision.transforms import ToTensor, Lambda, Compose
import matplotlib.pyplot as plt
1. DataLoader
DataLoader is the class PyTorch commonly uses to load data. It wraps a dataset and returns an iterable that yields the data in batches.
Parameters:
dataset (Dataset) – dataset from which to load the data.
batch_size (int, optional) – how many samples per batch to load (default: 1).
shuffle (bool, optional) – set to True to have the data reshuffled at every epoch (default: False).
2. torchvision
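torchvision is a package that contains, among other things, many commonly used vision datasets. Similar domain libraries include torchtext, torchaudio, ...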
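Here is a minimal sketch of the batch_size and shuffle parameters described above. It uses a toy TensorDataset of 10 numbers instead of FashionMNIST. The data and the batch size of 4 are assumptions chosen for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset: 10 samples with features 0..9 and labels equal to each value's parity
features = torch.arange(10, dtype=torch.float32).unsqueeze(1)  # shape (10, 1)
labels = torch.arange(10) % 2                                  # shape (10,)
dataset = TensorDataset(features, labels)

# Defaults (batch_size=1, shuffle=False): one sample per batch, in the original order
default_loader = DataLoader(dataset)
first_X, first_y = next(iter(default_loader))
print(first_X.shape, first_X.item())  # torch.Size([1, 1]) 0.0

# batch_size=4 yields batches of 4, 4 and 2 (the last batch keeps the remainder);
# shuffle=True draws a fresh random sample order at the start of every epoch
loader = DataLoader(dataset, batch_size=4, shuffle=True)
for epoch in range(2):
    sizes, order = [], []
    for X, y in loader:
        sizes.append(len(X))
        order.extend(X.squeeze(1).long().tolist())
    print(f"epoch {epoch}: batch sizes {sizes}, sample order {order}")
```

Every epoch prints batch sizes [4, 4, 2]. The sample order usually differs between the two epochs because shuffle=True reshuffles each time the loader is iterated.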
# Download training data from open datasets.
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)
1. Every dataset in torchvision.datasets (here, FashionMNIST) is a subclass of torch.utils.data.Dataset. Because these subclasses implement __getitem__ and __len__, they can be passed to torch.utils.data.DataLoader.
2. The constructor parameters of FashionMNIST include:
root (string) – Root directory of dataset where FashionMNIST/processed/training.pt and FashionMNIST/processed/test.pt exist.
train (bool, optional) – If True, creates dataset from training.pt, otherwise from test.pt.
download (bool, optional) – If True, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again.
transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop.
Original post: https://www.cnblogs.com/qinzihao/p/15077164.html