
About torchvision.datasets.CIFAR10

Date: 2019-11-10 21:44:12

The DARTS code written for PyTorch 0.4 contains these two lines:

trn_data = datasets.CIFAR10(root=data_path, train=True, download=False, transform=train_transform)
shape = trn_data.train_data.shape

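In PyTorch 1.2 and later, the source shows that the CIFAR10 class no longer has a train_data attribute; it has been replaced by data. The second line therefore raises an AttributeError and has to be changed to: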

shape = trn_data.data.shape
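
If the same script has to run against both old and new releases, one option is to look the attribute up by name instead of hard-coding it. The sketch below is only an illustration: get_image_array, OldStyleDataset and NewStyleDataset are hypothetical names that are not part of torchvision, and the stand-in classes exist only so the example runs without downloading CIFAR-10.

```python
import numpy as np


def get_image_array(dataset):
    """Return the raw image array of a CIFAR10-style dataset object.

    Newer torchvision releases expose `data`; older ones used
    `train_data` (or `test_data` for the test split).
    """
    for attr in ("data", "train_data", "test_data"):
        if hasattr(dataset, attr):
            return getattr(dataset, attr)
    raise AttributeError("no data/train_data/test_data attribute found")


# Hypothetical stand-ins that mimic the two attribute layouts.
class OldStyleDataset:
    def __init__(self):
        self.train_data = np.zeros((4, 32, 32, 3), dtype=np.uint8)


class NewStyleDataset:
    def __init__(self):
        self.data = np.zeros((4, 32, 32, 3), dtype=np.uint8)


old_shape = get_image_array(OldStyleDataset()).shape
new_shape = get_image_array(NewStyleDataset()).shape
print(old_shape, new_shape)
```

With a real dataset the call would be shape = get_image_array(trn_data).shape. If only one version needs to be supported, the plain trn_data.data.shape above is simpler.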

The relevant part of the datasets.CIFAR10 source (up to the end of __init__) is as follows:

from __future__ import print_function
from PIL import Image
import os
import os.path
import numpy as np
import sys

if sys.version_info[0] == 2:
    import cPickle as pickle
else:
    import pickle

from .vision import VisionDataset
from .utils import check_integrity, download_and_extract_archive


class CIFAR10(VisionDataset):
    """`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.

    Args:
        root (string): Root directory of dataset where directory
            ``cifar-10-batches-py`` exists or will be saved to if download is set to True.
        train (bool, optional): If True, creates dataset from training set, otherwise
            creates from test set.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    """
    base_folder = 'cifar-10-batches-py'
    url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    filename = "cifar-10-python.tar.gz"
    tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
    train_list = [
        ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
        ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
        ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
        ['data_batch_4', '634d18415352ddfa80567beed471001a'],
        ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
    ]

    test_list = [
        ['test_batch', '40351d587109b95175f43aff81a1287e'],
    ]
    meta = {
        'filename': 'batches.meta',
        'key': 'label_names',
        'md5': '5ff9c542aee3614f3951f8cda6e48888',
    }

    def __init__(self, root, train=True, transform=None, target_transform=None,
                 download=False):

        super(CIFAR10, self).__init__(root, transform=transform,
                                      target_transform=target_transform)

        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        if self.train:
            downloaded_list = self.train_list
        else:
            downloaded_list = self.test_list

        self.data = []
        self.targets = []

        # now load the pickled numpy arrays
        for file_name, checksum in downloaded_list:
            file_path = os.path.join(self.root, self.base_folder, file_name)
            with open(file_path, 'rb') as f:
                if sys.version_info[0] == 2:
                    entry = pickle.load(f)
                else:
                    entry = pickle.load(f, encoding='latin1')
                self.data.append(entry['data'])
                if 'labels' in entry:
                    self.targets.extend(entry['labels'])
                else:
                    self.targets.extend(entry['fine_labels'])

        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC

        self._load_meta()
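
The last two assignments to self.data explain the shape that trn_data.data.shape returns. The following sketch reproduces them on synthetic data, so no download is needed. The batch size of 5 and the variable names flat, chw and hwc are assumptions made for illustration; the real batch files hold many more rows.

```python
import numpy as np

# Two synthetic "batch files": each row is one flattened image of
# 3 * 32 * 32 = 3072 bytes (all of channel 0, then channel 1, then channel 2).
rng = np.random.default_rng(0)
batches = [rng.integers(0, 256, size=(5, 3072), dtype=np.uint8)
           for _ in range(2)]

flat = np.vstack(batches)           # stack the batches row-wise
chw = flat.reshape(-1, 3, 32, 32)   # N x C x H x W
hwc = chw.transpose((0, 2, 3, 1))   # N x H x W x C, as stored in self.data

# Check one pixel: channel c, row h, column w of image n sits at
# offset c * 1024 + h * 32 + w in the flattened row.
n, c, h, w = 7, 2, 5, 9
same_pixel = bool(hwc[n, h, w, c] == flat[n, c * 1024 + h * 32 + w])

print(flat.shape, chw.shape, hwc.shape, same_pixel)
```

This is why data has the layout (number of images, 32, 32, 3): the channel axis is moved to the end so that each image can be handed to PIL as an H x W x C array.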

 


Original: https://www.cnblogs.com/yqpy/p/11831717.html
