首页 > 其他 > 详细

CLASStorch.utils.data.TensorDataset(*tensors)

时间:2021-06-21 20:21:03      阅读:21      评论:0      收藏:0      [点我收藏+]

CLASStorch.utils.data.TensorDataset(*tensors)

Dataset wrapping tensors.

Each sample will be retrieved by indexing tensors along the first dimension.

Parameters

*tensors (Tensor) – tensors that have the same size of the first dimension.

 

train_features=torch.Tensor([[1.1,2.1,3.1],[4.1,5.1,6.1],[7.1,8.1,9.1],[10.1,11.1,12.1]])

train_labels=torch.Tensor([[1.1],[2.1],[3.1],[4.1]])
dataset = torch.utils.data.TensorDataset(train_features,
train_labels)
print(dataset)
for i in dataset:
    print(i)
输出结果:

<torch.utils.data.dataset.TensorDataset object at 0x0000023D5A814B38>
(tensor([1.1000, 2.1000, 3.1000]), tensor([1.1000]))
(tensor([4.1000, 5.1000, 6.1000]), tensor([2.1000]))
(tensor([7.1000, 8.1000, 9.1000]), tensor([3.1000]))
(tensor([10.1000, 11.1000, 12.1000]), tensor([4.1000]))

该函数将行数或列数相同的俩个维度数组进行拼接,在这个代码中俩个数据按行包装。

batch_size=2
train_iter = torch.utils.data.DataLoader(dataset,batch_size,shuffle=True)
print(train_iter)
for X,y in train_iter:
    print(X)
    print(y)
<torch.utils.data.dataloader.DataLoader object at 0x0000024E1888B898>

tensor([[1.1000, 2.1000, 3.1000],
[4.1000, 5.1000, 6.1000]])
tensor([[1.1000],
[2.1000]])
tensor([[ 7.1000, 8.1000, 9.1000],
[10.1000, 11.1000, 12.1000]])
tensor([[3.1000],
[4.1000]])

CLASStorch.utils.data.DataLoader,如图可知DataLoader是将包装好的n*(特征向量+标签)分成n/x批,每批(x*特征向量,x*标签)。该函数返回的是一个迭代对象。batch_size是批大小。

CLASStorch.utils.data.TensorDataset(*tensors)

原文:https://www.cnblogs.com/hahaah/p/14914603.html

(0)
(0)
   
举报
评论 一句话评论(0
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!