初学nn.Module,看不懂各种调用,后来看明白了,估计会忘,故写篇笔记记录
代码:
class A():
def __init__(self):
print('init函数')
def __call__(self, param):
print('call 函数', param)
a = A()
输出
上面的代码加一行
class A():
def __init__(self):
print('init函数')
def __call__(self, param):
print('call 函数', param)
a = A()
a(1)
输出
_ call_()中可以调用其它函数,如forward函数
class A(): def __init__(self): print('init函数') def __call__(self, param): print('call 函数', param) res = self.forward(param) return res 2 def forward(self, input_): print('forward 函数', input_) return input_ a = A() b = a(1) print('结果b =',b)
看了上面的例子,就知道了_call _()的作用,那下面看更接近CNN的例子
from torch import nn import torch class Ding(nn.Module): def __init__(self): print('init') super().__init__() def forward(self, input): output = input 1 print('forward') return output dzy = Ding() x = torch.tensor(1.0) out = dzy(x) print(out)
结果:
这里有很多参数,详细可见参考2。发现这里forward_call 要么是_slow_forward,要么是self.forward(),而这个_slow_forward()也会用self.forward()
当然,也可以重写__call__(),比如我们不让它使用forward()
from torch import nn import torch class Ding(nn.Module): def __init__(self): print('init') super().__init__() def __call__(self, input_): print('重写call, 不用forward') return 'hhh' def forward(self, input): output = input 1 print('forward') return output dzy = Ding() x = torch.tensor(1.0) out = dzy(x) print(out)
使用对象dzy(x)时,用了父类nn.Module的call函数,调用了forward,而这个forward又被我们在子类里重写了。
https://blog.csdn.net/dss_dssssd/article/details/83750838
https://zhuanlan.zhihu.com/p/366461413
联系客服