How to split a symbol?

I have a symbol created by concatenating data and seg along dim=1:

data with shape 64x3x224x224
seg with shape 64x1x224x224

Hence, the concatenated symbol has shape 64x4x224x224. I want to split it back into data and seg. How can I do that? This is what I tried:

data_seg = mx.symbol.split(data=data_seg, axis=1, num_outputs=2)
data = data_seg[0]
seg = data_seg[1]

However, data and seg both end up with shape 64x2x224x224 instead of 64x3x224x224 and 64x1x224x224, respectively.
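The 2-and-2 result comes from splitting the channel axis into equal parts. The sketch below is a NumPy analog, not MXNet code: np.split with an integer section count also divides the axis evenly. The batch size of 2 and the 8x8 spatial size are arbitrary stand-ins for 64 and 224x224 to keep it small.

```python
import numpy as np

# Small stand-in for the 64x4x224x224 tensor: 3 image channels
# followed by 1 segmentation channel along axis=1.
data = np.zeros((2, 3, 8, 8))
seg = np.ones((2, 1, 8, 8))
data_seg = np.concatenate([data, seg], axis=1)  # shape (2, 4, 8, 8)

# An integer section count splits the axis into equal parts:
# 4 channels / 2 outputs = 2 channels per output.
first, second = np.split(data_seg, 2, axis=1)
print(first.shape, second.shape)  # (2, 2, 8, 8) (2, 2, 8, 8)
```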

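mx.symbol.split() divides the axis into num_outputs equal parts, so splitting 4 channels with num_outputs=2 gives 2 channels per output. To cut at an arbitrary position, use mx.symbol.slice() instead: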
data = mx.symbol.slice(data_seg, begin=(None, 0, None, None), end=(None, 3, None, None))  # channels 0-2 -> 64x3x224x224
seg = mx.symbol.slice(data_seg, begin=(None, 3, None, None), end=(None, 4, None, None))  # channel 3 -> 64x1x224x224
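MXNet's slice uses half-open [begin, end) ranges on each axis, and None means the full extent of that axis. The snippet below uses plain NumPy slicing rather than MXNet, which follows the same convention. It checks which begin/end pair recovers each piece, again using small stand-in shapes. For symbols, mx.symbol.slice_axis(data_seg, axis=1, begin=0, end=3) is an equivalent way to cut a single axis.

```python
import numpy as np

# Same small stand-in: 3 image channels followed by 1 segmentation channel.
data = np.zeros((2, 3, 8, 8))
seg = np.ones((2, 1, 8, 8))
data_seg = np.concatenate([data, seg], axis=1)

# begin is inclusive and end is exclusive; an unbounded slice (None in
# MXNet's begin/end tuples) keeps the whole axis.
data_back = data_seg[:, 0:3, :, :]  # channels 0, 1, 2
seg_back = data_seg[:, 3:4, :, :]   # channel 3, kept as a size-1 axis

print(data_back.shape, seg_back.shape)  # (2, 3, 8, 8) (2, 1, 8, 8)
print(np.array_equal(data_back, data), np.array_equal(seg_back, seg))  # True True

# Stopping at 2 instead of 3 would drop the third image channel.
print(data_seg[:, 0:2].shape)  # (2, 2, 8, 8)
```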
