首页 > 其他 > 详细

Variable类

时间:2021-09-02 21:11:31      阅读:27      评论:0      收藏:0      [点我收藏+]

Variable类

Variable类在pytorch中定义,以下代码演示Variable的用法:

import torch
from torch.autograd import Variable
import numpy

tensor = torch.FloatTensor([[1, 2], [3, 4]])
var = Variable(tensor, requires_grad=True)
print(‘original matrix:‘, tensor)
print(‘\nVariable:‘, var)

t_out = torch.mean(tensor*tensor)
v_out = torch.mean(var*var)
print(‘\norigin squared and mean:‘, t_out)
print(‘\nVariable squared and mean:‘, v_out)

v_out.backward()
print(‘\ngradient after backward‘, var.grad)

print(‘\nvar.data is a tensor:‘, var.data)

print(‘\nthe tensor can become a numpy:‘, var.data.numpy())

输出结果:

original matrix: tensor([[1., 2.],
        [3., 4.]])

Variable: tensor([[1., 2.],
        [3., 4.]], requires_grad=True)

origin squared and mean: tensor(7.5000)

Variable squared and mean: tensor(7.5000, grad_fn=<MeanBackward0>)

gradient after backward tensor([[0.5000, 1.0000],
        [1.5000, 2.0000]])

var.data is a tensor: tensor([[1., 2.],
        [3., 4.]])

the tensor can become a numpy: [[1. 2.]
 [3. 4.]]

Variable对象可以进行与tensor相同的运算。

同时能够记录梯度。

它的一个属性data是tensor类型。

Variable类

原文:https://www.cnblogs.com/cxf-tech/p/15219829.html

(0)
(0)
   
举报
评论 一句话评论(0
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!