博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
torch中实现自定义层
阅读量:6696 次
发布时间:2019-06-25

本文共 3250 字,大约阅读时间需要 10 分钟。

当我们要实现自己的一些idea时,torch自带的模块和函数已经不能满足,我们需要自己实现层(或者类),一般的做法是把自定义层加入到已有的torch模块中。

torch中实现自定义层

lua实现

如果自定义层的功能可以通过调用torch中已有的函数实现,那就只需要用lua实现,torch的文档中也提供了简单的。

现在我们来实现一个NewClass:

  • 在torch目录下(torch/extra/nn/)创建文件NewClass.lua

  • 参考nn中其他lua文件的结构写好模板,在对应的函数中实现想要的功能

--创建新类,从nn.Module继承local NewClass, Parent = torch.class('nn.NewClass', 'nn.Module')--初始化操作function NewClass:__init()   Parent.__init(self)end--前向传播function NewClass:updateOutput(input)end--反向传播function NewClass:updateGradInput(input, gradOutput)end--损失对参数的偏导,也就是残差,如果该层没有要学习的参数,则不需要写这个函数function NewClass:accGradParameters(input, gradOutput)end
  • 在nn的init.lua中末尾添加一句

require('nn.NewClass')
  • 重新安装nn模块

cd torch/extra/nn/luarocks make rocks/nn-scm-1.rockspec
  • 安装成功后,在自己的代码中使用自定义的类了

require 'nn'...nn.NewClass()...

CPU实现

如果通过torch的函数不能实现出需要的功能,那么需要自己写C程序实现核心功能,然后在NewClass.lua中调用。

  • torch/extra/nn/lib/THNN/generic/目录下创建文件NewClass.c

  • 参考nn中已有的实现,在函数中实现需要的功能

...void THNN_(NewClass_updateOutput)(          THNNState *state,          THTensor *input,          THTensor *output){}void THNN_(NewClass_updateGradInput)(          THNNState *state,          THTensor *input,          THTensor *gradOutput,          THTensor *gradInput){}...
  • 声明已实现的函数,在torch/extra/nn/lib/THNN/generic/THNN.h中添加

...TH_API void THNN_(NewClass_updateOutput)(          THNNState *state,                      THTensor *input,                       THTensor *output);           TH_API void THNN_(NewClass_updateGradInput)(          THNNState *state,                      THTensor *input,                      THTensor *gradOutput,                THTensor *gradInput);    ...
  • 添加include,在torch/extra/nn/lib/THNN/init.c中,添加

#include "generic/SpatialConvolution.c"#include "THGenerateFloatTypes.h"
  • 在NewClass.lua中调用CPU版本的函数

...function NewClass:updateOutput(input)   input.THNN.NewClass_updateOutput(      input:cdata(),      self.output:cdata()   )   return self.outputendfunction NewClass:updateGradInput(input, gradOutput)   if self.gradInput then      input.THNN.NewClass_updateGradInput(         input:cdata(),         self.gradInput:cdata(),         gradOutput:cdata()      )      return self.gradInput   endend...
  • 重新编译安装nn

cd torch/extra/nn/luarocks make rocks/nn-scm-1.rockspec
  • 安装成功后,在自己的代码中使用自定义的类了

require 'nn'...nn.NewClass()...

Cuda实现

如果想要进一步提升运算效率,需要自己写一个Cuda版本的程序。

  • torch/extra/cunn/lib/THCUNN/目录下创建文件NewClass.cu

  • 参考cunn中已有的函数,实现函数功能

...void THNN_CudaNewClass_updateOutput(THCState *state, THCudaTensor *input, THCudaTensor *output){}void THNN_CudaNewClass_updateGradInput(THCState *state, THCudaTensor *input, THCudaTensor *gradOutput, THCudaTensor *gradInput){}...
  • 声明函数,在torch/extra/cunn/lib/THCUNN/THCUNN.h中添加:

TH_API void THNN_CudaNewClass_updateOutput(          THCState *state,          THCudaTensor *input,          THCudaTensor *output);TH_API void THNN_CudaNewClass_updateGradInput(          THCState *state,          THCudaTensor *input,          THCudaTensor *gradOutput,          THCudaTensor *gradInput);
  • 在NewClass.lua中调用GPU版本的函数,和CPU版本一样,都通过THNN调用

  • 重新编译安装cunn

cd torch/extra/cunn/luarocks make rocks/cunn-scm-1.rockspec
  • 安装成功后,在自己的代码中使用自定义的类了

require 'cunn'...nn.NewClass()...

测试

torch/extra/nn/test.luatorch/extra/cunn/test.lua中添加测试代码,可以用来测试NewClass的输出是否正确,具体可参考已有的测试代码。

添加好后,运行th -lnn -e "nn.test{'NewClass'}"即可测试。

转载地址:http://kwtoo.baihongyu.com/

你可能感兴趣的文章
java 电子商务云平台b2b b2c o2o springmvc+mybatis+spring cloud+spring boot
查看>>
如何通过ad组策略让domain users用户可以远程桌面?
查看>>
线程池的使用
查看>>
vb的winio模拟键盘鼠标部分参考代码
查看>>
等待多个并发事件完成的模型
查看>>
如何使用 PyCharm+Docker 打造深度学习利器
查看>>
十大压力测试工具,收下
查看>>
Maven学习总结(八)——使用Maven构建多模块项目
查看>>
易宝典文章——怎样管理Exchange Server 2013邮箱邮件流功能之传递选项
查看>>
Interested Transaction List ( ITL ) in Oracle
查看>>
Centos 6.3 install Darwin Streaming Server 6.0.3
查看>>
个人博客的推广
查看>>
VUE页面渲染问题
查看>>
浮点型
查看>>
81.node.js前端html时页面格式错乱解决办法
查看>>
this与super关键字
查看>>
Word 2010 插入其他文件的方法
查看>>
BZOJ4766: 文艺计算姬(Prufer序列)
查看>>
ECMAScript 5 —— 单体内置对象之Global对象
查看>>
AGC 018E.Sightseeing Plan——网格路径问题观止
查看>>