I’d like to open discussion of the design of a simple reshaping layer that allows dimensions to be easily merged (by which I mean flattened together).
This would be the inverse of a layer extension that is currently under review, namely sub-setting support in reshape_like: https://github.com/apache/incubator-mxnet/pull/11928. The extension implemented by that PR makes it very easy to split dimensions that have previously been merged. This proposal takes care of the inverse case: merging several dimensions together. This is a common operation when applying the same operation across spatial or temporal sequences, where it is typical to fold these extra dimensions into the batch dimension, apply the operation, and then split that batch dimension back up again.
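The fold, apply, split pattern described above can be sketched in plain Python on nested lists. This is only an illustration: the helper names `fold`, `unfold`, and `per_step_op` are hypothetical and are not part of MXNet or of this proposal.

```python
def fold(x):
    """Merge the first two axes of a nested list: [B][T][F] -> [B*T][F]."""
    return [step for seq in x for step in seq]


def unfold(flat, time_steps):
    """Split the leading axis back up: [B*T][...] -> [B][T][...]."""
    return [flat[i:i + time_steps] for i in range(0, len(flat), time_steps)]


def per_step_op(features):
    """Hypothetical stand-in for an operation that only knows about a batch axis."""
    return sum(features)


# Batch of 2 sequences, 3 time steps each, 2 features per step.
x = [[[1, 2], [3, 4], [5, 6]],
     [[7, 8], [9, 10], [11, 12]]]

folded = fold(x)                             # shape [6, 2]
applied = [per_step_op(f) for f in folded]   # shape [6]
y = unfold(applied, time_steps=3)            # shape [2, 3]
print(y)  # [[3, 7, 11], [15, 19, 23]]
```

The merge step is what the proposed op would cover; the split step corresponds to the reshape_like extension mentioned above.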
The proposed merge_dims op would allow different dimensions to be merged. It would take as a parameter a list, target_axes, that gives, for each axis of the input, the axis of the output to send it to. If multiple input axes are mapped to the same output axis, they are merged. This list must be the same length as the rank of the input.
So, for example, this table shows on the left the input shape, in the middle the target_axes parameter, and on the right the output shape that the layer would produce:
[2, 3, 5] -> [0, 0, 0] -> [30]
[2, 3, 5] -> [0, 0, 1] -> [6, 5]
[2, 3, 5] -> [0, 1, 1] -> [2, 15]
[2, 3, 5] -> [0, 1, 0] -> [10, 3]
[2, 3, 5] -> [1, 0, 1] -> [3, 10]
[2, 3, 5] -> [1, 0, 0] -> [15, 2]
[2, 3, 5] -> [0, 1, 2] -> [2, 3, 5]
[2, 3, 5] -> [2, 1, 0] -> [5, 3, 2]
Let’s explain in more detail one of these examples:
in_shape = [2,3,5]
target_axes = [0, 1, 1]
output_shape = [2, 15]
First, keep in mind that target_axes uses zero indexing and that its values name axes of the output: 0 refers to the first output axis, 1 refers to the second, etc.
The initial 0 of target_axes means “send the first axis of the input to the first axis of the output”; therefore the first element of output_shape is 2.
Next, the remaining 1, 1 of target_axes means “send the second and third axes of the input to the second axis of the output”; therefore the second element of output_shape is 3*5=15.
This mechanism is quite flexible and allows for arbitrary re-ordering of the dimensions in addition to merging. The final two rows of the table involve no merging at all: [0, 1, 2] leaves the shape unchanged, and [2, 1, 0] reverses the axes.
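One way to read a target_axes list is to group the input axes by the output axis they are sent to. The sketch below does only that grouping; the helper name `group_axes` is made up for illustration and is not part of the proposal.

```python
def group_axes(target_axes):
    """For each output axis, list the input axes that are sent to it."""
    groups = [[] for _ in range(1 + max(target_axes))]
    for in_axis, out_axis in enumerate(target_axes):
        groups[out_axis].append(in_axis)
    return groups


print(group_axes([0, 1, 1]))  # [[0], [1, 2]]   -> input axes 1 and 2 are merged
print(group_axes([0, 1, 0]))  # [[0, 2], [1]]   -> input axes 0 and 2 are merged
print(group_axes([2, 1, 0]))  # [[2], [1], [0]] -> pure reversal, no merging
```

A group with more than one entry is a merge, and the order of the groups gives the re-ordering.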
In Python, the exact semantics would be as follows:
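def out_shape(in_shape, target_axes):
    # One target axis per input axis, and no negative axes.
    assert len(in_shape) == len(target_axes)
    assert min(target_axes) >= 0
    # Start every output dimension at 1, then multiply in each input
    # dimension that is sent to it.
    out_shape = [1] * (1 + max(target_axes))
    for (in_dim, target_axis) in zip(in_shape, target_axes):
        out_shape[target_axis] *= in_dim
    return out_shape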
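The shape function above fixes the output shape but not where each element lands. One possible element-level reading is sketched below, under the assumption (not stated in the proposal) that input axes sharing an output axis are combined in row-major order, with earlier input axes varying slowest. The name `out_index` is hypothetical, and `out_shape` is repeated in compact form only so that the snippet runs on its own.

```python
from itertools import product


def out_shape(in_shape, target_axes):
    # Same rule as above: each output dim is the product of the input dims sent to it.
    shape = [1] * (1 + max(target_axes))
    for in_dim, target_axis in zip(in_shape, target_axes):
        shape[target_axis] *= in_dim
    return shape


def out_index(in_index, in_shape, target_axes):
    # ASSUMPTION: axes merged into one output axis are combined in row-major
    # order, in the order they appear in the input.
    index = [0] * (1 + max(target_axes))
    for i, in_dim, target_axis in zip(in_index, in_shape, target_axes):
        index[target_axis] = index[target_axis] * in_dim + i
    return index


in_shape, target_axes = [2, 3, 5], [0, 1, 1]
print(out_index((1, 2, 4), in_shape, target_axes))  # [1, 14]

# Every input element should land on a distinct, in-range output position.
shape = out_shape(in_shape, target_axes)
seen = {tuple(out_index(idx, in_shape, target_axes))
        for idx in product(*(range(d) for d in in_shape))}
assert len(seen) == 2 * 3 * 5
assert all(0 <= i < d for pos in seen for i, d in zip(pos, shape))
```

Whether merged axes should be combined in input order or in some other order is a design choice that the shape rule alone does not settle.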
We will implement this if approved.
Edit: unwithdrawn.