首页 > 其他 > 详细

pytorch中调用C进行扩展

时间:2019-11-14 16:07:56      阅读:153      评论:0      收藏:0      [点我收藏+]

pytorch中调用C进行扩展,使得某些功能在CPU上运行更快;

第一步:编写头文件

/* src/my_lib.h */
int my_lib_add_forward(THFloatTensor *input1, THFloatTensor *input2, THFloatTensor *output);
int my_lib_add_backward(THFloatTensor *grad_output, THFloatTensor *grad_input);

 

第二步:编写源文件

/* src/my_lib.c */
#include <TH/TH.h>

int my_lib_add_forward(THFloatTensor *input1, THFloatTensor *input2,
THFloatTensor *output)
{
    if (!THFloatTensor_isSameSizeAs(input1, input2))
        return 0;
    THFloatTensor_resizeAs(output, input1);
    THFloatTensor_cadd(output, input1, 1.0, input2);
    return 1;
}

int my_lib_add_backward(THFloatTensor *grad_output, THFloatTensor *grad_input)
{
    THFloatTensor_resizeAs(grad_input, grad_output);
    THFloatTensor_fill(grad_input, 1);
    return 1;
}

 

注意:头文件TH就是pytorch底层代码的接口头文件,它是CPU模式,GPU下则为THC;

技术分享图片

 

 第三步:在同级目录下创建一个.py文件(比如叫“build.py”)

该文件用于对该C扩展模块进行编译(使用torch.util.ffi模块进行扩展编译);

# build.py
from torch.utils.ffi import create_extension
ffi = create_extension(
name=_ext.my_lib,        # 输出文件地址及名称
headers=src/my_lib.h,    # 编译.h文件地址及名称
sources=[src/my_lib.c],  # 编译.c文件地址及名称
with_cuda=False            # 不使用cuda
)
ffi.build()

 

第四步:编写.py脚本调用编译好的C扩展模块

import torch
from torch.autograd import Function
from _ext import my_lib
import torch.nn as nn

class MyAddFunction(Function):
    def forward(self, input1, input2):
        output = torch.FloatTensor()
        my_lib.my_lib_add_forward(input1, input2, output)
        return output

    def backward(self, grad_output):
        grad_input = torch.FloatTensor()
        my_lib.my_lib_add_backward(grad_input, grad_output)
        return grad_input

class MyAddModule(nn.Module):
    def forward(self, input1, input2):
        return MyAddFunction()(input1, input2)

class MyNetWork(nn.Module):
    def __init__(self):
        super(MyNetWork, self).__init__()
        self.add = MyAddModule()

    def forward(self, input1, input2):
        return self.add(input1, input2)

model = MyNetWork()
input1, input2 = torch.randn(5, 5), torch.randn(5, 5)
print(model(input1, input2))
print(input1 + input2)

 

至此,用这个简单的例子抛砖引玉~

 

pytorch中调用C进行扩展

原文:https://www.cnblogs.com/zf-blog/p/11857580.html

(0)
(0)
   
举报
评论 一句话评论(0
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!