
Training a multi-branch network with PyTorch


 

How to train a multi-branch output network?

[Figure] A multi-branch output network: one Encoder, multiple Decoders.

We want each branch's loss L_i to update only that branch's own parameters, while the shared (Encoder) parameters are updated with the sum of all branch losses, Sum(L_i).
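This falls out of the chain rule automatically: L_1 only touches Decoder1's parameters (via the Encoder's output), so d Sum(L_i) / d(Decoder1 params) = d L_1 / d(Decoder1 params), while for the shared parameters d Sum(L_i) / d(Encoder params) = Sum_i d L_i / d(Encoder params). A single backward() on Sum(L_i) therefore gives each decoder exactly its own gradient and the encoder the accumulated gradient.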

 

 

import torch

# Encoder/Decoder1-3, loss_fn, input, ground_truth, learning_rate and the
# batch*_ind index tensors come from the surrounding training setup.
# Put the shared and branch parameters into one optimizer.
params = (list(Encoder.parameters()) + list(Decoder1.parameters())
          + list(Decoder2.parameters()) + list(Decoder3.parameters()))
optimizer = torch.optim.Adadelta(params, learning_rate)

encoder_output = Encoder(input)
optimizer.zero_grad()
loss_1, loss_2, loss_3 = 0, 0, 0
# split the batch: batch1_ind / batch2_ind / batch3_ind are index tensors
# selecting the samples routed to each decoder
if encoder_output[batch1_ind].shape[0] != 0:
    output1 = Decoder1(encoder_output[batch1_ind])
    loss_1 = loss_fn(output1, ground_truth[batch1_ind])

if encoder_output[batch2_ind].shape[0] != 0:
    output2 = Decoder2(encoder_output[batch2_ind])
    loss_2 = loss_fn(output2, ground_truth[batch2_ind])

if encoder_output[batch3_ind].shape[0] != 0:
    output3 = Decoder3(encoder_output[batch3_ind])
    loss_3 = loss_fn(output3, ground_truth[batch3_ind])

# note: at least one branch must receive samples, otherwise loss stays a
# plain Python 0 and backward() will fail
loss = loss_1 + loss_2 + loss_3
loss.backward()
optimizer.step()
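As a sanity check on the pattern above, here is a minimal self-contained sketch (toy nn.Linear modules and random data; all names here are illustrative, not from the original post) showing that one backward on the summed loss leaves each decoder with only its own gradient while the encoder accumulates both:

import torch
import torch.nn as nn

torch.manual_seed(0)
encoder = nn.Linear(4, 8)
decoder1 = nn.Linear(8, 2)
decoder2 = nn.Linear(8, 2)
loss_fn = nn.MSELoss()

x = torch.randn(6, 4)
y1 = torch.randn(6, 2)
y2 = torch.randn(6, 2)

h = encoder(x)
loss_1 = loss_fn(decoder1(h), y1)
loss_2 = loss_fn(decoder2(h), y2)
(loss_1 + loss_2).backward()

# decoder1 / decoder2 each get the gradient of their own loss only;
# encoder.weight.grad is the sum of the contributions from both losses
print(decoder1.weight.grad.norm(), decoder2.weight.grad.norm(), encoder.weight.grad.norm())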

  

 

## you can simply do:

o1, o2 = mm(input)   # mm is the multi-branch model; it returns one output per branch
o = o1 + o2
# compute the loss on o, then call backward()

## Or you can do

l1 = loss(o1, target)                 # loss for branch 1
l2 = loss2(o2, target2)               # loss for branch 2
torch.autograd.backward([l1, l2])     # backprop both losses at once
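For scalar losses, torch.autograd.backward([l1, l2]) accumulates the gradients of both losses in a single pass, which ends up equivalent to calling (l1 + l2).backward(): each branch receives the gradient of its own loss, and the shared parameters receive the sum.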

 

If you want different branches to use different optimizers:

opt1 = optim.Adam(branch_1.parameters(), ...)
opt2 = optim.SGD(branch_2.parameters(), ...)
...
...
loss = 2 * loss_1 + 3 * loss_2   # weighted sum of the branch losses
loss.backward()
opt1.step()
opt2.step()
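A slightly fuller version of this idea might look like the sketch below (module and optimizer names are illustrative assumptions; the shared encoder gets its own optimizer here, and zero_grad() is called on every optimizer before backward):

import torch
import torch.nn as nn
import torch.optim as optim

encoder = nn.Linear(4, 8)          # shared trunk
branch_1 = nn.Linear(8, 2)
branch_2 = nn.Linear(8, 2)
loss_fn = nn.MSELoss()

opt_enc = optim.Adam(encoder.parameters(), lr=1e-3)
opt1 = optim.Adam(branch_1.parameters(), lr=1e-3)
opt2 = optim.SGD(branch_2.parameters(), lr=1e-2)

x = torch.randn(6, 4)
target_1 = torch.randn(6, 2)
target_2 = torch.randn(6, 2)

for opt in (opt_enc, opt1, opt2):
    opt.zero_grad()

h = encoder(x)
loss_1 = loss_fn(branch_1(h), target_1)
loss_2 = loss_fn(branch_2(h), target_2)

loss = 2 * loss_1 + 3 * loss_2     # weighted sum of branch losses
loss.backward()                    # one backward fills grads for all three modules

for opt in (opt_enc, opt1, opt2):
    opt.step()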

  

 


Original post: https://www.cnblogs.com/jiangkejie/p/14635297.html
