It’s hard to tell what exactly you are trying to achieve without any code.
However, I work with a network that needs to propagate feature maps and a certain error through the graph, so it might be somewhat similar. Here are some code snippets:
class HybridDualSequential(nn.HybridSequential):
    """Stack HybridBlocks that each take and return two values, sequentially.

    Every child block is called as ``block(x, y)`` and must return a pair
    ``(x, y)``. Both values are threaded from one block to the next, and the
    second value returned by each block is added into a running ``error_sum``.
    """

    def __init__(self, prefix=None, params=None):
        super(HybridDualSequential, self).__init__(prefix=prefix, params=params)

    def hybrid_forward(self, F, x, stacked_w=None):
        """Run all child blocks in the order they were added.

        Parameters
        ----------
        F : module
            ``mxnet.ndarray`` or ``mxnet.symbol``, supplied by Gluon.
        x : NDArray or Symbol
            First input of the first child block.
        stacked_w : NDArray or Symbol, optional
            Second input of the first child block. ``None`` is forwarded
            unchanged, so the child must cope with it.

        Returns
        -------
        tuple
            ``(x, error_sum)``: the first output of the last block, and the
            sum of every block's second output added onto a zero array of
            shape ``(1,)`` (that zero array is returned as-is when there are
            no children).
        """
        y = stacked_w
        error_sum = F.zeros(shape=(1,))
        for block in self._children.values():
            # Call the block itself rather than block.hybrid_forward: only
            # __call__/forward injects the block's registered parameters and
            # honours hybridize()/cached graphs.
            x, y = block(x, y)
            # Plain ``+`` on Symbols is element-wise without broadcasting, so
            # a (1,) accumulator plus a larger y would fail once hybridized.
            error_sum = F.broadcast_add(error_sum, y)
        return x, error_sum
This is similar to the sequential block (`nn.HybridSequential`) used in the ResNet code; the only difference is that it passes two objects on instead of one. Now you can build the network from blocks whose `hybrid_forward` functions take these two values as inputs, like so:
class BasicBlockV1(HybridBlock):
    """Example block that takes and returns an ``(x, stacked_w)`` pair.

    ``x`` is passed through a 7x7 convolution (no bias); ``stacked_w`` has the
    constant vector ``[0, 0, 1]`` added to it. Designed to be stacked inside
    ``HybridDualSequential``.

    Parameters
    ----------
    channels : int or sequence of int
        Number of output channels. If a list/tuple is given, its first element
        is used (the original snippet indexed ``channels[0]``).
    stride : int
        Stride of the 7x7 convolution.
    downsample : bool, default False
        If True, build ``self.downsample`` (2x2 average pool + 1x1 conv).
        NOTE(review): this branch is built but not used in ``hybrid_forward``.
    in_channels : int, default 0
        Input channels of the downsample 1x1 conv (0 lets Gluon infer it).
    clip_threshold : float, default 1.0
        Accepted for interface compatibility; currently unused.
    prefix : str, default ''
        Gluon name prefix.
    """

    def __init__(self, channels, stride, downsample=False, in_channels=0, clip_threshold=1.0,
                 prefix='', **kwargs):
        super(BasicBlockV1, self).__init__(prefix=prefix, **kwargs)
        # Accept a plain int (as in the usage example) as well as a sequence;
        # indexing an int with [0] raises TypeError.
        if isinstance(channels, (list, tuple)):
            channels = channels[0]
        with self.name_scope():
            self.conv = nn.Conv2D(channels, kernel_size=7, strides=stride, padding=3,
                                  use_bias=False)
            # Register the fixed vector as a Gluon Constant rather than a raw
            # NDArray attribute. Gluon then passes it into hybrid_forward as
            # the ``fixed_data`` kwarg (NDArray or Symbol, matching F), so the
            # block keeps working after hybridize().
            self.fixed_data = self.params.get_constant('fixed_data', [0, 0, 1])
            if downsample:
                self.downsample = nn.HybridSequential(prefix=prefix)
                self.downsample.add(nn.AvgPool2D(pool_size=2, strides=2, padding=0))
                self.downsample.add(
                    nn.Conv2D(channels, kernel_size=1, strides=1, in_channels=in_channels,
                              use_bias=False, prefix="sc_qconv_"))
            else:
                self.downsample = None

    def hybrid_forward(self, F, x, stacked_w=None, fixed_data=None):
        """Return ``(conv(x), stacked_w + fixed_data)``.

        ``fixed_data`` is injected by Gluon from the registered constant; do
        not pass it yourself. A missing ``stacked_w`` (e.g. the first block of
        ``seq(a)``) is treated as zero.
        """
        # You can use x and stacked_w here. If you only need them in the last
        # layer, just pass stacked_w on until you reach the layer that uses it,
        # or use constants such as fixed_data if your data does not change
        # with the batch size.
        if stacked_w is None:
            # identity() returns a fresh array instead of aliasing the constant.
            return self.conv(x), F.identity(fixed_data)
        return self.conv(x), F.broadcast_add(stacked_w, fixed_data)
Then you can create a model by adding these blocks to the HybridDualSequential:
# Build a model from ten BasicBlockV1 blocks (64 channels, stride 2).
# Call seq.initialize() before feeding data through it.
seq = HybridDualSequential()
for _ in range(10):
    seq.add(BasicBlockV1(64, 2))
You can then call `seq` on your data, either as `seq(a)` or as `seq(a, b)`.
Hope that helps.