How to set symbol connections by indexing

Imagine the following scenario:
Network A output: a symbol with shape [1,1,5,8]
Network B output: a symbol with shape [1,1,5,8]
Network C accepts input with shape [1,1,8,8]

What I need to do is merge the output symbols of networks A and B as follows:

VAR = zeros(1,1,8,8)
VAR[0:1, 0:1, 0:5, :] = VAR[0:1 ,0:1, 0:5, :] + Net_A_Output
VAR[0:1, 0:1, 3:8, :] = VAR[0:1, 0:1, 3:8, :] + Net_B_Output

and VAR is used as input to network C.
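The pseudocode above can be made concrete as a small NumPy sketch. NumPy is an assumption used only to spell out the intended arithmetic; it does not build the symbolic MXNet graph. The names net_a_output and net_b_output are hypothetical stand-ins for the two network outputs, and they hold distinct values instead of ones so each row of the result can be traced to its source: rows 0-2 come from A alone, rows 3-4 are A[3:5] + B[0:2], and rows 5-7 come from B alone.

```python
import numpy as np

# Hypothetical stand-ins for the outputs of networks A and B,
# filled with distinct values so misplaced rows would be visible.
net_a_output = np.arange(40, dtype=np.float32).reshape(1, 1, 5, 8)
net_b_output = 100 + np.arange(40, dtype=np.float32).reshape(1, 1, 5, 8)

# Imperative version of the desired merge: accumulate A into rows 0-4
# and B into rows 3-7 of a zero-initialized (1, 1, 8, 8) buffer.
var = np.zeros((1, 1, 8, 8), dtype=np.float32)
var[0:1, 0:1, 0:5, :] += net_a_output
var[0:1, 0:1, 3:8, :] += net_b_output

print(var.shape)  # (1, 1, 8, 8)
```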

Now the question is: how can I do this with the symbolic API so that all three networks form a single graph? I need to update the whole graph based on a single loss.

In other words, there are functions such as “slice” that return a smaller symbol by indexing another symbol, but is there any function that works the other way, i.e. one that sets some indices of a symbol from another symbol?

Here is one way. It is not the fastest or most elegant, but it should work and you should get the correct gradients:

Split the outputs of A and B along the 3rd dimension (axis 2), add the last two slices of A to the first two slices of B, then concatenate everything back together along the 3rd dimension.

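import mxnet as mx

# data: stand-ins for the outputs of networks A and B
a = mx.nd.ones(shape=(1,1,5,8))
b = mx.nd.ones(shape=(1,1,5,8))

# split: a_ and b_ are Python lists of 5 NDArrays, each of shape (1,1,1,8)
a_ = mx.nd.split(a, 5, axis=2)
b_ = mx.nd.split(b, 5, axis=2)

# mix: rows 3 and 4 of A overlap rows 0 and 1 of B
mixed_1 = a_[-2] + b_[0]
mixed_2 = a_[-1] + b_[1]

# concat: 3 rows of A only + 2 mixed rows + 3 rows of B only = 8 rows
final = a_[:-2] + [mixed_1, mixed_2] + b_[2:]
output = mx.nd.concat(*final, dim=2)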

print(output)
[[[[1. 1. 1. 1. 1. 1. 1. 1.]
   [1. 1. 1. 1. 1. 1. 1. 1.]
   [1. 1. 1. 1. 1. 1. 1. 1.]
   [2. 2. 2. 2. 2. 2. 2. 2.]
   [2. 2. 2. 2. 2. 2. 2. 2.]
   [1. 1. 1. 1. 1. 1. 1. 1.]
   [1. 1. 1. 1. 1. 1. 1. 1.]
   [1. 1. 1. 1. 1. 1. 1. 1.]]]]
<NDArray 1x1x8x8 @cpu(0)>
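Because every input above is all ones, this printout cannot show whether the right slices were paired. A minimal NumPy sketch of the same split/mix/concat steps with distinct values can be compared against the direct slice assignment from the question. Here np.split and np.concatenate stand in for mx.nd.split and mx.nd.concat; this is an illustration under that assumption, not MXNet code.

```python
import numpy as np

a = np.arange(40, dtype=np.float32).reshape(1, 1, 5, 8)
b = 100 + np.arange(40, dtype=np.float32).reshape(1, 1, 5, 8)

# split into 5 slices of shape (1, 1, 1, 8) along axis 2
a_ = np.split(a, 5, axis=2)
b_ = np.split(b, 5, axis=2)

# mix: last two slices of A overlap the first two slices of B
mixed_1 = a_[-2] + b_[0]
mixed_2 = a_[-1] + b_[1]

# concat back together along axis 2
final = a_[:-2] + [mixed_1, mixed_2] + b_[2:]
output = np.concatenate(final, axis=2)

# reference: the in-place version from the question
expected = np.zeros((1, 1, 8, 8), dtype=np.float32)
expected[:, :, 0:5, :] += a
expected[:, :, 3:8, :] += b

print(np.array_equal(output, expected))  # True
```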

Thanks @ThomasDelteil
The problem still exists with symbol variables, since in the # mix section the NDArrays are accessed by indexing:

# mix
mixed_1 = a_[-2] + b_[0]
mixed_2 = a_[-1] + b_[1]

But if a and b are symbol variables (mx.sym.var), there is no obvious way to index the outputs of split. In other words, this script will raise an error for symbolic variables.

“After split they are Python lists of symbols, not symbols, so I believe the problem doesn’t exist: you are indexing the list, not the symbol.”

EDIT: You are right. I thought symbolic split worked like NDArray split, returning a list of symbols, but it returns a single symbol instead. Let me see if I can find an alternative.

Here is a solution that does what you want. I added some parametrization to make it more flexible. It is probably a bit slow, but it should be fine for small arrays.

import mxnet as mx
from mxnet import gluon, nd
from mxnet.gluon import nn

class OverlapAdd(nn.HybridBlock):
    
    def __init__(self, overlap=2, axis=2, axis_size=5, **kwargs):
        """
        Parameters
        ----------
        overlap : int
            How many rows the two arrays should overlap
        axis : int
            On which dimension the overlap should happen
        axis_size : int
            The size of the arrays on the overlap dimension
        """
        super(OverlapAdd, self).__init__(**kwargs)
        self.axis = axis
        self.axis_size = axis_size
        self.overlap = overlap

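    def hybrid_forward(self, F, a, b):
        # split both inputs into axis_size slices along the overlap axis
        a_0 = F.split(a, self.axis_size, axis=self.axis)
        b_0 = F.split(b, self.axis_size, axis=self.axis)
        # add the last `overlap` slices of a to the first `overlap` slices of b,
        # keeping their order along the axis (a[size-overlap+i] pairs with b[i])
        mixed = []
        for i in range(self.overlap):
            mixed.append(a_0[self.axis_size - self.overlap + i] + b_0[i])
        # non-overlapping parts of a and b
        a_1 = F.concat(*a_0[:self.axis_size-self.overlap], dim=self.axis)
        b_1 = F.concat(*b_0[self.overlap:], dim=self.axis)
        mixed = F.concat(*mixed, dim=self.axis)
        # final layout: a only | a + b | b only
        output = F.concat(a_1, mixed, b_1, dim=self.axis)
        return output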

    
net = OverlapAdd()
net.hybridize()
a = mx.nd.ones(shape=(1,1,5,8))
b = mx.nd.ones(shape=(1,1,5,8))

net(a,b)

output

[[[[1. 1. 1. 1. 1. 1. 1. 1.]
   [1. 1. 1. 1. 1. 1. 1. 1.]
   [1. 1. 1. 1. 1. 1. 1. 1.]
   [2. 2. 2. 2. 2. 2. 2. 2.]
   [2. 2. 2. 2. 2. 2. 2. 2.]
   [1. 1. 1. 1. 1. 1. 1. 1.]
   [1. 1. 1. 1. 1. 1. 1. 1.]
   [1. 1. 1. 1. 1. 1. 1. 1.]]]]
<NDArray 1x1x8x8 @cpu(0)>

It also works on other axes with different overlap sizes:

net = OverlapAdd(axis=3, axis_size=8, overlap=4)
net.hybridize()
a = mx.nd.ones(shape=(1,1,5,8))
b = mx.nd.ones(shape=(1,1,5,8))
print(net(a,b))

output

[[[[1. 1. 1. 1. 2. 2. 2. 2. 1. 1. 1. 1.]
   [1. 1. 1. 1. 2. 2. 2. 2. 1. 1. 1. 1.]
   [1. 1. 1. 1. 2. 2. 2. 2. 1. 1. 1. 1.]
   [1. 1. 1. 1. 2. 2. 2. 2. 1. 1. 1. 1.]
   [1. 1. 1. 1. 2. 2. 2. 2. 1. 1. 1. 1.]]]]
<NDArray 1x1x5x12 @cpu(0)>
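As with the first example, all-ones inputs cannot reveal whether the overlapping slices are paired in the right order. The following NumPy sketch mirrors the split/mix/concat logic along an arbitrary axis and compares it with a zero-initialized buffer filled by slice addition. The functions overlap_add and overlap_add_reference are hypothetical names, and NumPy stands in for the MXNet F namespace; this is a sanity check under those assumptions, not the Gluon block itself.

```python
import numpy as np

def overlap_add(a, b, overlap, axis):
    """NumPy sketch of OverlapAdd: the last `overlap` slices of `a` along
    `axis` are added to the first `overlap` slices of `b`."""
    size = a.shape[axis]
    a_parts = np.split(a, size, axis=axis)
    b_parts = np.split(b, size, axis=axis)
    # pair a[size - overlap + i] with b[i], preserving order along the axis
    mixed = [a_parts[size - overlap + i] + b_parts[i] for i in range(overlap)]
    pieces = a_parts[:size - overlap] + mixed + b_parts[overlap:]
    return np.concatenate(pieces, axis=axis)

def overlap_add_reference(a, b, overlap, axis):
    """Same result via a zero-initialized buffer and in-place slice addition."""
    size = a.shape[axis]
    out_shape = list(a.shape)
    out_shape[axis] = 2 * size - overlap
    out = np.zeros(out_shape, dtype=a.dtype)
    lead = [slice(None)] * a.ndim
    tail = [slice(None)] * a.ndim
    lead[axis] = slice(0, size)
    tail[axis] = slice(size - overlap, 2 * size - overlap)
    out[tuple(lead)] += a
    out[tuple(tail)] += b
    return out

a = np.arange(40, dtype=np.float32).reshape(1, 1, 5, 8)
b = 100 + np.arange(40, dtype=np.float32).reshape(1, 1, 5, 8)

for axis, overlap in [(2, 2), (3, 4)]:
    out = overlap_add(a, b, overlap, axis)
    print(out.shape, np.array_equal(out, overlap_add_reference(a, b, overlap, axis)))
# (1, 1, 8, 8) True
# (1, 1, 5, 12) True
```

The output length along the overlap axis is 2 * axis_size - overlap, which gives 8 rows for the first configuration and 12 columns for the second, matching the shapes printed above.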

Symbol file of the model:

{
  "nodes": [
    {
      "op": "null", 
      "name": "data0", 
      "inputs": []
    }, 
    {
      "op": "SliceChannel", 
      "name": "overlapadd2_split0", 
      "attrs": {
        "axis": "2", 
        "num_outputs": "5"
      }, 
      "inputs": [[0, 0, 0]]
    }, 
    {
      "op": "Concat", 
      "name": "overlapadd2_concat0", 
      "attrs": {
        "dim": "2", 
        "num_args": "3"
      }, 
      "inputs": [[1, 0, 0], [1, 1, 0], [1, 2, 0]]
    }, 
    {
      "op": "null", 
      "name": "data1", 
      "inputs": []
    }, 
    {
      "op": "SliceChannel", 
      "name": "overlapadd2_split1", 
      "attrs": {
        "axis": "2", 
        "num_outputs": "5"
      }, 
      "inputs": [[3, 0, 0]]
    }, 
    {
      "op": "elemwise_add", 
      "name": "overlapadd2__plus0", 
      "inputs": [[1, 4, 0], [4, 0, 0]]
    }, 
    {
      "op": "elemwise_add", 
      "name": "overlapadd2__plus1", 
      "inputs": [[1, 3, 0], [4, 1, 0]]
    }, 
    {
      "op": "Concat", 
      "name": "overlapadd2_concat2", 
      "attrs": {
        "dim": "2", 
        "num_args": "2"
      }, 
      "inputs": [[5, 0, 0], [6, 0, 0]]
    }, 
    {
      "op": "Concat", 
      "name": "overlapadd2_concat1", 
      "attrs": {
        "dim": "2", 
        "num_args": "3"
      }, 
      "inputs": [[4, 2, 0], [4, 3, 0], [4, 4, 0]]
    }, 
    {
      "op": "Concat", 
      "name": "overlapadd2_concat3", 
      "attrs": {
        "dim": "2", 
        "num_args": "3"
      }, 
      "inputs": [[2, 0, 0], [7, 0, 0], [8, 0, 0]]
    }
  ], 
  "arg_nodes": [0, 3], 
  "node_row_ptr": [
    0, 
    1, 
    6, 
    7, 
    8, 
    13, 
    14, 
    15, 
    16, 
    17, 
    18
  ], 
  "heads": [[9, 0, 0]], 
  "attrs": {"mxnet_version": ["int", 10401]}
}

Thanks @ThomasDelteil