首页 > 编程语言 > 详细

(十)pytorch多线程训练,DataLoader的num_works参数设置

时间:2021-07-30 16:16:02      阅读:112      评论:0      收藏:0      [点我收藏+]

一、概述

数据集较小时(小于2W)建议num_works不用管默认就行,因为用了反而比没用慢。
当数据集较大时建议采用,num_works一般设置为(CPU线程数+-1)为最佳,可以用以下代码找出最佳num_works(注意windows用户如果要使用多核多线程必须把训练放在if __name__ == ‘__main__‘:下才不会报错)

二、代码

import time
import torch.utils.data as d
import torchvision
import torchvision.transforms as transforms
 
 
if __name__ == __main__:
    BATCH_SIZE = 100
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (0.5,))])
    train_set = torchvision.datasets.MNIST(\mnist, download=False, train=True, transform=transform)
    
    # data loaders
    train_loader = d.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
    
    for num_workers in range(20):
        train_loader = d.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=num_workers)
        # training ...
        start = time.time()
        for epoch in range(1):
            for step, (batch_x, batch_y) in enumerate(train_loader):
                pass
        end = time.time()
        print(num_workers is {} and it took {} seconds.format(num_workers, end - start))

 三、查看线程数

1、cpu个数

grep physical id /proc/cpuinfo | sort -u

2、核心数

grep core id /proc/cpuinfo | sort -u | wc -l

3、线程数

grep processor /proc/cpuinfo | sort -u | wc -l

4、例子

命令执行结果如图所示,根据结果得知,此服务器有1个cpu,6个核心,每个核心2线程,共12线程。

技术分享图片

 

(十)pytorch多线程训练,DataLoader的num_works参数设置

原文:https://www.cnblogs.com/zhangxianrong/p/15079209.html

(0)
(0)
   
举报
评论 一句话评论(0
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!