首页 > 其他 > 详细

pytorch入门到项目(八)数据模型的读取(2)

时间:2020-12-17 17:24:14      阅读:29      评论:0      收藏:0      [点我收藏+]
import numpy as np
import torch
import os
import random
from PIL import Image
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
# Paths are resolved relative to this script's own directory (not the
# current working directory), so the script can be launched from anywhere.
BASE_PATH = os.path.abspath(__file__)
base_path = os.path.abspath(os.path.join(os.path.dirname(BASE_PATH), 'TestDir'))
# Not referenced by the loading code below; presumably the location of the
# full RMB dataset (TODO confirm).
data_dir = os.path.abspath(os.path.join(os.path.dirname(BASE_PATH), 'RMB_data'))
# Seed both RNGs: Python's `random` alone does not affect torchvision's
# RandomCrop or DataLoader(shuffle=True), which draw from torch's generator.
random.seed(1)
torch.manual_seed(1)
# Maps each class sub-directory name under base_path to its integer label.
test_label = {"1": 0, "100": 1}
# Build (image_path, label) pairs from the immediate class sub-directories of
# base_path. Each sub-directory name (e.g. "1", "100") is looked up in
# test_label to get its integer class index; only ".jpg" files are collected.
# Entries are sorted so the resulting list order is deterministic.
data_info = list()
for sub_dir in sorted(os.listdir(base_path)):
    sub_dir_path = os.path.join(base_path, sub_dir)
    # Skip stray files and folders that have no label mapping instead of
    # failing with a bare KeyError.
    if not os.path.isdir(sub_dir_path) or sub_dir not in test_label:
        continue
    t_label = test_label[sub_dir]  # already an int; no conversion needed
    for img_name in sorted(os.listdir(sub_dir_path)):
        if img_name.endswith('.jpg'):
            data_info.append((os.path.join(sub_dir_path, img_name), t_label))
# Preprocessing pipeline, built once rather than on every loop iteration:
# resize to 32x32, random 32x32 crop with 4px padding, then convert to a tensor.
# NOTE: the transform is applied here, at load time, so RandomCrop is sampled
# only once per image and every epoch sees the same crop. For per-epoch
# augmentation, apply the transform inside a Dataset.__getitem__ instead.
ok_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
])
# Decode every image listed in data_info and store (tensor, label) pairs.
new_data_info = list()
for x_dir, x_label in data_info:
    # The context manager closes the file handle promptly. convert('RGB')
    # returns a fully loaded copy, so it stays valid after the file is closed.
    with Image.open(x_dir) as raw_img:
        x_img = raw_img.convert('RGB')
    new_data_info.append((ok_transform(x_img), x_label))

# Sanity check: number of loaded (image tensor, label) samples.
print(len(new_data_info))

# Shuffle the pre-transformed samples into batches of 14 and print the
# index of each batch as it is drawn.
newdataLoader = DataLoader(new_data_info, batch_size=14, shuffle=True)
for batch_idx, _batch in enumerate(newdataLoader):
    print(batch_idx)

pytorch入门到项目(八)数据模型的读取(2)

原文:https://www.cnblogs.com/wuxero/p/14150447.html

(0)
(0)
   
举报
评论 一句话评论(0)
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!