# How to set symbol connections by indexing

Imagine the following scenario:

- Network A output: a symbol with shape `[1,1,5,8]`
- Network B output: a symbol with shape `[1,1,5,8]`
- Network C accepts input with shape `[1,1,8,8]`

What I need to do is merge the output symbols of networks A and B in the following way (rows 3 and 4 of the result receive contributions from both):

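``````python
import numpy as np
Net_A_Output = np.ones((1, 1, 5, 8))  # placeholder for network A's output
Net_B_Output = np.ones((1, 1, 5, 8))  # placeholder for network B's output
VAR = np.zeros((1, 1, 8, 8))
VAR[0:1, 0:1, 0:5, :] = VAR[0:1, 0:1, 0:5, :] + Net_A_Output
VAR[0:1, 0:1, 3:8, :] = VAR[0:1, 0:1, 3:8, :] + Net_B_Output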
``````

and `VAR` is used as the input to network C.

The question is how to do this with the symbolic API so that all three networks form a single graph, because I need to update the whole graph based on a single loss.

In other words, there are functions such as `slice` that return a smaller symbol by indexing another symbol, but is there a function that works the other way, i.e. one that sets some indices of a symbol from another symbol?
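For reference, the assignment pattern above does not strictly need a "set by index" operator: adding each output into a window of a zero tensor gives the same result as zero-padding each output to the full length along axis 2 and summing. The sketch below uses NumPy purely to check that arithmetic; it is not MXNet code, and `merge_by_padding` is a hypothetical helper name. In MXNet the closest symbolic counterpart is presumably the `pad` operator with `mode='constant'`, but I have not checked its axis restrictions for this layout.

```python
import numpy as np

def merge_by_padding(net_a_output, net_b_output, total_rows=8):
    """Hypothetical helper: place A at the top and B at the bottom of a zero
    tensor along axis 2, summing where they overlap, without index assignment."""
    rows_a = net_a_output.shape[2]
    rows_b = net_b_output.shape[2]
    # Zero-pad A after its last row and B before its first row so that both
    # reach total_rows along axis 2; the other axes are left unpadded.
    pad_a = ((0, 0), (0, 0), (0, total_rows - rows_a), (0, 0))
    pad_b = ((0, 0), (0, 0), (total_rows - rows_b, 0), (0, 0))
    return np.pad(net_a_output, pad_a) + np.pad(net_b_output, pad_b)

# Distinct values per row make the overlap pairing visible.
a = np.arange(40, dtype=float).reshape(1, 1, 5, 8)
b = 100 + np.arange(40, dtype=float).reshape(1, 1, 5, 8)

# Reference: the in-place indexing from the question.
var = np.zeros((1, 1, 8, 8))
var[0:1, 0:1, 0:5, :] += a
var[0:1, 0:1, 3:8, :] += b

merged = merge_by_padding(a, b)
print(np.array_equal(merged, var))  # True
```

Padding is the mirror image of slicing here: slicing drops rows, constant padding re-inserts zero rows, so the sum lands in exactly the rows the indexed version writes to.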

One way, not the fastest or the most elegant, but it should work and give you the correct gradients:

split the outputs of A and B along the 3rd dimension (axis 2), add the last two slices of A to the first two slices of B, and concatenate everything back together along the 3rd dimension.

``````python
import mxnet as mx

# data
a = mx.nd.ones(shape=(1,1,5,8))
b = mx.nd.ones(shape=(1,1,5,8))

# split
a_ = mx.nd.split(a, 5, axis=2)
b_ = mx.nd.split(b, 5, axis=2)

# mix
mixed_1 = a_[-2] + b_[0]
mixed_2 = a_[-1] + b_[1]

# concat
final = a_[:-2]+[mixed_1, mixed_2]+b_[2:]
output = mx.nd.concat(*final, dim=2)

print(output)
``````
``````
[[[[1. 1. 1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1. 1. 1.]
[2. 2. 2. 2. 2. 2. 2. 2.]
[2. 2. 2. 2. 2. 2. 2. 2.]
[1. 1. 1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1. 1. 1.]]]]
<NDArray 1x1x8x8 @cpu(0)>
``````
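Because `a` and `b` are all ones, the printed output cannot show whether the overlapping rows are paired in the right order. As a quick check of the index arithmetic, here is the same recipe in NumPy, with `np.split`/`np.concatenate` standing in for `mx.nd.split`/`mx.nd.concat` (I am assuming they behave the same for equal-sized splits), using distinct row values and compared against the in-place indexing from the question:

```python
import numpy as np

# Distinct values per row so a mis-paired overlap would show up
# (with all-ones inputs every pairing gives the same result).
a = np.arange(40, dtype=float).reshape(1, 1, 5, 8)
b = 100 + np.arange(40, dtype=float).reshape(1, 1, 5, 8)

# split into 5 one-row slices along axis 2
a_ = np.split(a, 5, axis=2)
b_ = np.split(b, 5, axis=2)

# mix: the last two rows of A meet the first two rows of B, in order
mixed_1 = a_[-2] + b_[0]
mixed_2 = a_[-1] + b_[1]

# concat everything back along axis 2
final = a_[:-2] + [mixed_1, mixed_2] + b_[2:]
output = np.concatenate(final, axis=2)

# Reference: the in-place indexing from the question
var = np.zeros((1, 1, 8, 8))
var[0:1, 0:1, 0:5, :] += a
var[0:1, 0:1, 3:8, :] += b

print(output.shape, np.array_equal(output, var))  # (1, 1, 8, 8) True
```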

Thanks @ThomasDelteil. The problem still exists with symbol variables, because in the `# mix` section the NDArrays are accessed by indexing.

``````python
# mix
mixed_1 = a_[-2] + b_[0]
mixed_2 = a_[-1] + b_[1]
``````

But if `a` and `b` are symbol variables (`mx.sym.var`), there is no obvious way to index the results of the split (`a_` and `b_`) used to build `mixed_1` and `mixed_2`. In other words, this script will raise an error for symbolic variables.

“After `split` they are Python lists of symbols, not symbols, so I believe the problem doesn't exist: you are indexing the list, not the symbol.”

EDIT: You are right. I thought the symbolic `split` worked like the NDArray `split` and returned a list of symbols, but it returns a single symbol (with multiple outputs) instead. Let me see if I can find an alternative.

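Here is a solution that does what you want. I added some parametrization to make it more flexible. It is probably still a bit slow, but it should be fine for smaller arrays. It works in symbolic mode because indexing a symbol that has multiple outputs, such as the result of `F.split`, returns the corresponding single output.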

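``````python
import mxnet as mx
from mxnet import gluon


class OverlapAdd(gluon.HybridBlock):
    """Add the last `overlap` slices of `a` to the first `overlap` slices of `b`
    along `axis`, and concatenate the rest of both inputs around the sum."""

    def __init__(self, overlap=2, axis=2, axis_size=5, **kwargs):
        """
        Parameters
        ----------
        overlap : int
            How many slices the two arrays should overlap along `axis`
        axis : int
            On which dimension the overlap should happen
        axis_size : int
            The size of the arrays on the overlap dimension
        """
        super(OverlapAdd, self).__init__(**kwargs)
        self.axis = axis
        self.axis_size = axis_size
        self.overlap = overlap

    def hybrid_forward(self, F, a, b):
        # Split both inputs into axis_size slices along the overlap axis.
        # With F = mx.sym this is one symbol with axis_size outputs,
        # and indexing it selects a single output.
        a_0 = F.split(a, self.axis_size, axis=self.axis)
        b_0 = F.split(b, self.axis_size, axis=self.axis)
        # Slice (axis_size - overlap + i) of a lines up with slice i of b.
        mixed = []
        for i in range(self.overlap):
            mixed.append(a_0[self.axis_size - self.overlap + i] + b_0[i])
        a_1 = F.concat(*a_0[:self.axis_size - self.overlap], dim=self.axis)
        b_1 = F.concat(*b_0[self.overlap:], dim=self.axis)
        mixed = F.concat(*mixed, dim=self.axis)
        output = F.concat(a_1, mixed, b_1, dim=self.axis)
        return output


net = OverlapAdd()
net.hybridize()
a = mx.nd.ones(shape=(1, 1, 5, 8))
b = mx.nd.ones(shape=(1, 1, 5, 8))

print(net(a, b))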
``````

output

``````
[[[[1. 1. 1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1. 1. 1.]
[2. 2. 2. 2. 2. 2. 2. 2.]
[2. 2. 2. 2. 2. 2. 2. 2.]
[1. 1. 1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1. 1. 1.]]]]
<NDArray 1x1x8x8 @cpu(0)>
``````

It also works on other axes with different overlap sizes:

``````python
net = OverlapAdd(axis=3, axis_size=8, overlap=4)
net.hybridize()
a = mx.nd.ones(shape=(1,1,5,8))
b = mx.nd.ones(shape=(1,1,5,8))
print(net(a,b))
``````

output

``````
[[[[1. 1. 1. 1. 2. 2. 2. 2. 1. 1. 1. 1.]
[1. 1. 1. 1. 2. 2. 2. 2. 1. 1. 1. 1.]
[1. 1. 1. 1. 2. 2. 2. 2. 1. 1. 1. 1.]
[1. 1. 1. 1. 2. 2. 2. 2. 1. 1. 1. 1.]
[1. 1. 1. 1. 2. 2. 2. 2. 1. 1. 1. 1.]]]]
<NDArray 1x1x5x12 @cpu(0)>
``````
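Both demos above use all-ones inputs, so they confirm the output shape and which region gets the value 2, but not that each slice of `a` is added to the matching slice of `b`. Below is a NumPy mirror of the parametrized arithmetic, checked against direct window indexing with distinct values. The helper names `overlap_add` and `overlap_add_reference` are hypothetical, and, like `OverlapAdd` with its single `axis_size`, the sketch assumes both inputs have the same size `n` along the overlap axis, so the result has size `2 * n - overlap` there.

```python
import numpy as np

def overlap_add(a, b, overlap, axis):
    """NumPy mirror of the split/add/concat arithmetic: the last `overlap`
    slices of `a` are added to the first `overlap` slices of `b`."""
    n = a.shape[axis]
    a_0 = np.split(a, n, axis=axis)
    b_0 = np.split(b, n, axis=axis)
    # Slice n - overlap + i of a lines up with slice i of b.
    mixed = [a_0[n - overlap + i] + b_0[i] for i in range(overlap)]
    return np.concatenate(a_0[:n - overlap] + mixed + b_0[overlap:], axis=axis)

def overlap_add_reference(a, b, overlap, axis):
    """Direct version: add a and b into windows of a zero tensor."""
    n = a.shape[axis]
    shape = list(a.shape)
    shape[axis] = 2 * n - overlap
    out = np.zeros(shape)
    idx_a = [slice(None)] * a.ndim
    idx_b = [slice(None)] * a.ndim
    idx_a[axis] = slice(0, n)
    idx_b[axis] = slice(n - overlap, 2 * n - overlap)
    out[tuple(idx_a)] += a
    out[tuple(idx_b)] += b
    return out

a = np.arange(40, dtype=float).reshape(1, 1, 5, 8)
b = 100 + np.arange(40, dtype=float).reshape(1, 1, 5, 8)

# The two configurations used in the thread.
for overlap, axis in [(2, 2), (4, 3)]:
    out = overlap_add(a, b, overlap, axis)
    print(out.shape, np.array_equal(out, overlap_add_reference(a, b, overlap, axis)))
# (1, 1, 8, 8) True
# (1, 1, 5, 12) True
```

Using row-distinct inputs like these is a cheap way to catch a reversed or shifted pairing in the overlap loop, which all-ones inputs cannot detect.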

Symbol file of the model:

``````json
{
"nodes": [
{
"op": "null",
"name": "data0",
"inputs": []
},
{
"op": "SliceChannel",
"attrs": {
"axis": "2",
"num_outputs": "5"
},
"inputs": [[0, 0, 0]]
},
{
"op": "Concat",
"attrs": {
"dim": "2",
"num_args": "3"
},
"inputs": [[1, 0, 0], [1, 1, 0], [1, 2, 0]]
},
{
"op": "null",
"name": "data1",
"inputs": []
},
{
"op": "SliceChannel",
"attrs": {
"axis": "2",
"num_outputs": "5"
},
"inputs": [[3, 0, 0]]
},
{
"inputs": [[1, 4, 0], [4, 0, 0]]
},
{
"inputs": [[1, 3, 0], [4, 1, 0]]
},
{
"op": "Concat",
"attrs": {
"dim": "2",
"num_args": "2"
},
"inputs": [[5, 0, 0], [6, 0, 0]]
},
{
"op": "Concat",
"attrs": {
"dim": "2",
"num_args": "3"
},
"inputs": [[4, 2, 0], [4, 3, 0], [4, 4, 0]]
},
{
"op": "Concat",
"attrs": {
"dim": "2",
"num_args": "3"
},
"inputs": [[2, 0, 0], [7, 0, 0], [8, 0, 0]]
}
],
"arg_nodes": [0, 3],
"node_row_ptr": [
0,
1,
6,
7,
8,
13,
14,
15,
16,
17,
18
],