Here is a solution that does what you want. I added some parametrization to make it more flexible. It is probably a bit slow, but it should be fine for smaller arrays.
import mxnet as mx
from mxnet import gluon, nd
from mxnet.gluon import nn
class OverlapAdd(nn.HybridBlock):
    """Join two arrays along one axis, summing the region where they overlap.

    Both inputs are cut into ``axis_size`` slices along ``axis``. The last
    ``overlap`` slices of the first input are added element-wise, in order,
    to the first ``overlap`` slices of the second input. The untouched
    leading slices of the first input and trailing slices of the second
    input are kept on either side of the summed region.

    Parameters
    ----------
    overlap : int
        How many rows the two arrays should overlap
    axis : int
        On which dimension the overlap should happen
    axis_size : int
        The size of the arrays on the overlap dimension

    Raises
    ------
    ValueError
        If ``axis_size`` is smaller than 1, or ``overlap`` is negative or
        larger than ``axis_size``.
    """

    def __init__(self, overlap=2, axis=2, axis_size=5, **kwargs):
        super(OverlapAdd, self).__init__(**kwargs)
        # Fail at construction time instead of with an obscure indexing or
        # concat error in the middle of the forward pass.
        if axis_size < 1:
            raise ValueError(
                "axis_size must be at least 1, got {}".format(axis_size))
        if overlap < 0 or overlap > axis_size:
            raise ValueError(
                "overlap must be between 0 and axis_size ({}), got {}".format(
                    axis_size, overlap))
        self.axis = axis
        self.axis_size = axis_size
        self.overlap = overlap

    def _split(self, F, x):
        """Return ``x`` cut into ``axis_size`` slices along ``axis`` as a list."""
        if self.axis_size == 1:
            # NOTE(review): with a single output, F.split appears to hand back
            # the array itself rather than a sequence -- confirm. Wrapping the
            # input directly avoids indexing into the wrong dimension.
            return [x]
        pieces = F.split(x, self.axis_size, axis=self.axis)
        # Index one by one so the result is a plain list for both the
        # imperative (NDArray) and the hybridized (Symbol) code path.
        return [pieces[i] for i in range(self.axis_size)]

    def hybrid_forward(self, F, a, b):
        """Overlap-add ``a`` and ``b`` along ``self.axis``.

        Parameters
        ----------
        F : module
            The ``mxnet.nd`` or ``mxnet.sym`` namespace supplied by Gluon.
        a : NDArray or Symbol
            First input; its trailing ``overlap`` slices are summed.
        b : NDArray or Symbol
            Second input; its leading ``overlap`` slices are summed.

        Returns
        -------
        NDArray or Symbol
            The concatenation of ``2 * axis_size - overlap`` slices along
            ``self.axis``.
        """
        a_parts = self._split(F, a)
        b_parts = self._split(F, b)

        # Number of leading slices of ``a`` that are not part of the overlap.
        keep = self.axis_size - self.overlap

        # Slice ``keep + i`` of ``a`` lines up with slice ``i`` of ``b``; this
        # keeps the tail of ``a`` in its original order inside the overlap.
        summed = [a_parts[keep + i] + b_parts[i] for i in range(self.overlap)]

        # A single ordered list also covers overlap == 0 (plain concatenation)
        # and overlap == axis_size (only the summed slices remain).
        pieces = a_parts[:keep] + summed + b_parts[self.overlap:]
        if len(pieces) == 1:
            return pieces[0]
        return F.concat(*pieces, dim=self.axis)
# Default configuration (overlap=2, axis=2, axis_size=5) on two all-ones inputs.
net = OverlapAdd()
net.hybridize()
a = mx.nd.ones(shape=(1, 1, 5, 8))
b = mx.nd.ones(shape=(1, 1, 5, 8))
# Print explicitly so the result is shown when run as a script, not only in a REPL.
print(net(a, b))
Output:
[[[[1. 1. 1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1. 1. 1.]
[2. 2. 2. 2. 2. 2. 2. 2.]
[2. 2. 2. 2. 2. 2. 2. 2.]
[1. 1. 1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1. 1. 1.]]]]
<NDArray 1x1x8x8 @cpu(0)>
It also works along other dimensions and with different overlap sizes:
# Overlap 4 columns along the last axis (size 8) of two all-ones inputs.
net = OverlapAdd(overlap=4, axis=3, axis_size=8)
net.hybridize()
shape = (1, 1, 5, 8)
a = mx.nd.ones(shape=shape)
b = mx.nd.ones(shape=shape)
result = net(a, b)
print(result)
Output:
[[[[1. 1. 1. 1. 2. 2. 2. 2. 1. 1. 1. 1.]
[1. 1. 1. 1. 2. 2. 2. 2. 1. 1. 1. 1.]
[1. 1. 1. 1. 2. 2. 2. 2. 1. 1. 1. 1.]
[1. 1. 1. 1. 2. 2. 2. 2. 1. 1. 1. 1.]
[1. 1. 1. 1. 2. 2. 2. 2. 1. 1. 1. 1.]]]]
<NDArray 1x1x5x12 @cpu(0)>
Symbol file of the model (axis=2, axis_size=5, overlap=2):
{
"nodes": [
{
"op": "null",
"name": "data0",
"inputs": []
},
{
"op": "SliceChannel",
"name": "overlapadd2_split0",
"attrs": {
"axis": "2",
"num_outputs": "5"
},
"inputs": [[0, 0, 0]]
},
{
"op": "Concat",
"name": "overlapadd2_concat0",
"attrs": {
"dim": "2",
"num_args": "3"
},
"inputs": [[1, 0, 0], [1, 1, 0], [1, 2, 0]]
},
{
"op": "null",
"name": "data1",
"inputs": []
},
{
"op": "SliceChannel",
"name": "overlapadd2_split1",
"attrs": {
"axis": "2",
"num_outputs": "5"
},
"inputs": [[3, 0, 0]]
},
{
"op": "elemwise_add",
"name": "overlapadd2__plus0",
"inputs": [[1, 4, 0], [4, 0, 0]]
},
{
"op": "elemwise_add",
"name": "overlapadd2__plus1",
"inputs": [[1, 3, 0], [4, 1, 0]]
},
{
"op": "Concat",
"name": "overlapadd2_concat2",
"attrs": {
"dim": "2",
"num_args": "2"
},
"inputs": [[5, 0, 0], [6, 0, 0]]
},
{
"op": "Concat",
"name": "overlapadd2_concat1",
"attrs": {
"dim": "2",
"num_args": "3"
},
"inputs": [[4, 2, 0], [4, 3, 0], [4, 4, 0]]
},
{
"op": "Concat",
"name": "overlapadd2_concat3",
"attrs": {
"dim": "2",
"num_args": "3"
},
"inputs": [[2, 0, 0], [7, 0, 0], [8, 0, 0]]
}
],
"arg_nodes": [0, 3],
"node_row_ptr": [
0,
1,
6,
7,
8,
13,
14,
15,
16,
17,
18
],
"heads": [[9, 0, 0]],
"attrs": {"mxnet_version": ["int", 10401]}
}