首页 > 其他 > 详细

Pytorch 第一个程序

时间:2021-07-29 22:26:56      阅读:33      评论:0      收藏:0      [点我收藏+]
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets  #这里指定当前数据集为torchvision
from torchvision.transforms import ToTensor, Lambda, Compose
import matplotlib.pyplot as plt

1. DataLoader

是Pytorch用来加载数据的常用的类,返回一个可遍历的数据集对象

传入参数:

  • dataset (Dataset) – dataset from which to load the data.

  • batch_size (intoptional) – how many samples per batch to load (default: 1).

  • shuffle (booloptional) – set to True to have the data reshuffled at every epoch (default: False)

2. torchvision

是一个包,里面包含了很多常用的视觉数据集。类似的还有torchtext, torchaudio,...

 

# Download training data from open datasets.
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

1. torchvision.datasets里的所有datasets(这里是FashionMNIST)都是torch.utils.data.Dataset的子类,因为这些子类都写了__getitem__和__len__,所以可以被传入torch.utils.data.DataLoader。

 2. FashionMNIST的属性有:

  • root (string) – Root directory of dataset where FashionMNIST/processed/training.pt and FashionMNIST/processed/test.pt exist.

  • train (booloptional) – If True, creates dataset from training.pt, otherwise from test.pt.

  • download (booloptional) – If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again.

  • transform (callableoptional) – A function/transform that takes in an PIL image and returns a transformed version. E.g, transforms.RandomCrop

 

Pytorch 第一个程序

原文:https://www.cnblogs.com/qinzihao/p/15077164.html

(0)
(0)
   
举报
评论 一句话评论(0
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!