self.params的使用注意点

2022-07-28

from mxnet.gluon import nn
from mxnet import nd
class MyDense(nn.HybridBlock):
    """Fully connected layer followed by ReLU, built from raw ``self.params`` entries.

    Parameters registered with ``self.params.get`` inside a ``HybridBlock`` are
    handed to ``hybrid_forward`` as keyword arguments named after the parameter
    (``weight`` and ``bias`` here). The method therefore never calls ``.data()``
    itself.

    Args:
        units: Number of output features (columns of ``weight``, length of ``bias``).
        in_units: Number of input features (rows of ``weight``).
    """

    def __init__(self, units, in_units, **kwargs):
        super().__init__(**kwargs)

        # Not used by hybrid_forward; left over from the original experiment.
        # Assigned as an attribute, it is registered as a child block.
        self.embedding = nn.Embedding(3, 5)
        self.weight = self.params.get('weight', shape=(in_units, units))
        self.bias = self.params.get('bias', shape=(units,))

    def hybrid_forward(self, F, x, weight, bias):
        """Return ``relu(dot(x, weight) + bias)``.

        ``F`` is the NDArray API when the block runs imperatively and the
        Symbol API after ``hybridize()``. ``weight`` and ``bias`` are supplied
        by Gluon, not by the caller.
        """
        # Symbol ``+`` is element-wise and does not broadcast, so
        # ``dot(...) + bias`` (shapes (N, units) vs (units,)) would fail once
        # the block is hybridized. broadcast_add works in both modes.
        linear = F.broadcast_add(F.dot(x, weight), bias)
        return F.relu(linear)

dense = MyDense(units=3, in_units=5)
dense.initialize()

# Imperative call: Gluon passes the registered weight/bias into hybrid_forward.
sample = nd.random.uniform(shape=(2, 5))
print(dense(sample))
print('hybrid_forward success!')

from mxnet.gluon import nn
from mxnet import nd
class MyDense(nn.Block):
    """Deliberately failing ``nn.Block`` version of the same dense layer.

    Unlike ``HybridBlock.hybrid_forward``, ``nn.Block.forward`` receives only
    the arguments the caller passes. ``weight`` and ``bias`` are therefore never
    supplied, and ``dense(x)`` raises ``TypeError: forward() missing 2 required
    positional arguments``. The working pattern for a plain Block is
    ``forward(self, x)`` reading ``self.weight.data(ctx=x.ctx)`` and
    ``self.bias.data(ctx=x.ctx)``.

    Args:
        units: Number of output features.
        in_units: Number of input features.
    """

    def __init__(self, units, in_units, **kwargs):
        super().__init__(**kwargs)

        # Unused child block, kept from the original experiment.
        self.embedding = nn.Embedding(3, 5)
        weight_shape = (in_units, units)
        bias_shape = (units,)
        self.weight = self.params.get('weight', shape=weight_shape)
        self.bias = self.params.get('bias', shape=bias_shape)

    def forward(self, x, weight, bias):
        # Not reached via dense(x): the call fails before the body runs because
        # weight and bias are required but never passed.
        return nd.relu(nd.dot(x, weight) + bias)

dense = MyDense(units=3, in_units=5)
dense.initialize()
try:
    # print() stays inside the try: MXNet may defer computation errors until
    # the result is read, and printing reads it.
    print(dense(nd.random.uniform(shape=(2, 5))))
except Exception as e:  # demo boundary: report whatever the call raises
    print('some error:', e)
    print('forward fail!')
else:
    print('forward success!')

 

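结果(`HybridBlock` 的 `hybrid_forward` 会自动收到通过 `self.params` 注册的 `weight`、`bias`;而 `nn.Block` 的 `forward` 只收到调用者传入的参数,所以报"缺少 2 个位置参数"。在 `nn.Block` 中应写成 `forward(self, x)`,并在内部用 `self.weight.data(ctx=x.ctx)`、`self.bias.data(ctx=x.ctx)` 取参数):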

[[0.         0.14012612 0.0058622 ]
 [0.         0.12333627 0.063691  ]]
<NDArray 2x3 @cpu(0)>
hybrid_forward success!
some error: forward() missing 2 required positional arguments: 'weight' and 'bias'
forward fail!
 

 

本文地址:https://blog.csdn.net/sinat_24395003/article/details/109644137

《self.params的使用注意点.doc》

下载本文的Word格式文档,以方便收藏与打印。