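The functions PyTorch provides for working with datasets are documented at: https://pytorch.org/docs/master/data.html#torch.utils.data.SubsetRandomSampler
That module, torch.utils.data, contains the classes used for dataset handling, such as DataLoader, Subset, and SubsetRandomSampler.
The ways PyTorch offers for splitting a dataset are shown below by example: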
import numpy as np
import torch
from torch.utils.data.sampler import SubsetRandomSampler

# MyCustomDataset and my_path stand for your own Dataset subclass and data path.
dataset = MyCustomDataset(my_path)
batch_size = 16
validation_split = .2
shuffle_dataset = True
random_seed = 42

# Creating data indices for training and validation splits:
dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(np.floor(validation_split * dataset_size))
if shuffle_dataset:
    np.random.seed(random_seed)
    np.random.shuffle(indices)
train_indices, val_indices = indices[split:], indices[:split]

# Creating PT data samplers and loaders:
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)

train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=train_sampler)
validation_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=valid_sampler)

# Usage Example:
num_epochs = 10
for epoch in range(num_epochs):
    # Train:
    for batch_index, (faces, labels) in enumerate(train_loader):
        pass  # run the training step on (faces, labels) here
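To make the index arithmetic in the sampler example easier to follow, here is a minimal sketch of just the split step, written with the standard library only. The helper name split_indices is hypothetical and not part of PyTorch, and it uses Python's random module instead of NumPy, so the shuffled order will differ from the NumPy version even with the same seed.

```python
import math
import random


def split_indices(dataset_size, validation_split=0.2, shuffle=True, seed=42):
    """Return (train_indices, val_indices) for a dataset of the given size.

    Hypothetical helper for illustration; not a PyTorch function.
    """
    indices = list(range(dataset_size))
    split = int(math.floor(validation_split * dataset_size))
    if shuffle:
        # A private Random instance leaves the global random state untouched.
        random.Random(seed).shuffle(indices)
    # The first `split` entries become validation, the rest training.
    return indices[split:], indices[:split]


train_indices, val_indices = split_indices(10)
print(len(train_indices), len(val_indices))
```

The two index lists are disjoint and together cover every sample, which is what lets both samplers draw from the same underlying dataset without overlap.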
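Alternatively, torch.utils.data.random_split splits a dataset into non-overlapping subsets of the given sizes:

import torch

# full_dataset stands for any Dataset instance, e.g. the one created above.
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])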
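The title of this post mentions torch.utils.data.Subset, so a small sketch of using it directly may help. The example below assumes a toy TensorDataset with 10 samples as a stand-in for a real dataset; the sizes and the seed are arbitrary choices for illustration.

```python
import torch
from torch.utils.data import Subset, TensorDataset, random_split

# Toy stand-in for a real dataset: 10 samples with 3 features each,
# labelled 0..9 so each sample can be traced back to its original position.
full_dataset = TensorDataset(torch.arange(30.0).reshape(10, 3), torch.arange(10))

# Subset wraps a dataset and exposes only the listed indices (a fixed split).
train_dataset = Subset(full_dataset, list(range(8)))
test_dataset = Subset(full_dataset, list(range(8, 10)))

# random_split also returns Subset objects; passing a seeded generator
# makes the random split repeatable across runs.
generator = torch.Generator().manual_seed(42)
rand_train, rand_test = random_split(full_dataset, [8, 2], generator=generator)

print(len(train_dataset), len(test_dataset), len(rand_train), len(rand_test))
```

Subset suits a fixed, deterministic split (for example the last 20% of a time-ordered dataset), while random_split is the shorter route when a random partition is acceptable.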
References:
https://www.cnblogs.com/marsggbo/p/10496696.html
https://likewind.top/2019/02/01/Pytorch-dataprocess/
https://blog.csdn.net/xholes/article/details/81410834
PyTorch dataset-splitting methods: torch.utils.data.Subset
Original post: https://www.cnblogs.com/Bella2017/p/11791216.html