Hi all,
I’ve identified a pattern of issues that keeps coming up for us when translating Mathematica’s high-level net representations into MXNet symbols.
Consider a net that takes a matrix M of size (n, 5) and a vector V of size 5. Let’s say we wish to catenate these, by first broadcasting the vector V from size (5) to size (n, 5), so that it is compatible with the matrix M. Then we catenate to produce a matrix of size (n, 10).
Fundamentally, n here is a dynamic dimension: we wish to take the exact same MX symbol and, via MXSymbolInferShape, create new symbols for various values of n.
This is not possible currently as far as I know. While you could imagine a broadcast_catenate symbol that makes this possible, this is not the only place that this particular issue comes up.
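To make the problem concrete, here is a minimal pure-Python sketch of the shape bookkeeping only. It is not MXNet code: build_graph_for, broadcast_vector and catenate_last_axis are hypothetical stand-ins I made up for illustration. The point is that the broadcast target (n, 5) has to be written into the graph as a constant, so a graph built for one n rejects any other n.

```python
def broadcast_vector(v_shape, target_shape):
    """Broadcast a vector of shape (k,) to an explicit target shape (n, k)."""
    if len(v_shape) != 1 or target_shape[-1] != v_shape[0]:
        raise ValueError(f"cannot broadcast {v_shape} to {target_shape}")
    return tuple(target_shape)


def catenate_last_axis(a_shape, b_shape):
    """Shape of catenating two tensors along their last axis."""
    if a_shape[:-1] != b_shape[:-1]:
        raise ValueError(f"leading dims differ: {a_shape} vs {b_shape}")
    return a_shape[:-1] + (a_shape[-1] + b_shape[-1],)


def build_graph_for(n):
    """Hypothetical 'compile' step: the broadcast target is baked in for one n."""
    target = (n, 5)  # constant stored in the graph; this is the problem

    def infer_shape(m_shape, v_shape):
        return catenate_last_axis(m_shape, broadcast_vector(v_shape, target))

    return infer_shape


graph = build_graph_for(3)
print(graph((3, 5), (5,)))  # works: the graph was built for n = 3

try:
    graph((4, 5), (5,))  # same graph, different n
except ValueError as err:
    print("shape inference failed:", err)
```

What we would like instead is for the target's first dimension to be read from M at shape-inference time rather than stored as a literal.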
Another example is when using ‘batch-flattening’ to map an FC layer over a sequence. In this example, say you have a sequence X of size (b, t, 3). Here, b is the batch dimension, t is the sequence (time) dimension, and 3 is the feature count.
Then let’s say you have an FC layer that wants an input of size (b, 3). By using the Reshape spec {-3, -2}, we can obtain a version of X in which the batch and time dimensions have been flattened together (-3 merges two consecutive dimensions into one, -2 copies all remaining dimensions). This reshaped tensor X’ has dimensions (b * t, 3).
Next, we apply the FC layer as normal, to obtain, say, an output tensor Y with shape (b * t, 4). Now, we must reshape this to have dimension (b, t, 4). Again, there is no way of doing this such that the same MX symbol will work for all values of t via appropriate calls to MXSymbolInferShape.
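The asymmetry between the two reshapes can be sketched in plain Python. The function below is my own re-implementation of a subset of the documented Reshape codes (positive literal, 0, -2, -3), not MXNet's implementation, and I am assuming the flattening spec is (-3, -2). Flattening is independent of b and t, whereas un-flattening needs the target dimensions written out as literals.

```python
from math import prod


def reshape_shape(in_shape, spec):
    """Apply a subset of MXNet-style Reshape codes to a shape tuple.

    Supported codes (an illustrative subset only):
      k > 0 : literal output dimension
      0     : copy the current input dimension
      -2    : copy all remaining input dimensions
      -3    : merge two consecutive input dimensions into one
    """
    out = []
    i = 0  # position in in_shape
    for code in spec:
        if code > 0:
            out.append(code)
            i += 1
        elif code == 0:
            out.append(in_shape[i])
            i += 1
        elif code == -2:
            out.extend(in_shape[i:])
            i = len(in_shape)
        elif code == -3:
            out.append(in_shape[i] * in_shape[i + 1])
            i += 2
        else:
            raise ValueError(f"unsupported reshape code: {code}")
    if prod(out) != prod(in_shape):
        raise ValueError(f"cannot reshape {in_shape} to {tuple(out)}")
    return tuple(out)


# Flattening: one spec works for every b and t.
flatten_spec = (-3, -2)
for b, t in [(2, 5), (2, 7)]:
    print(reshape_shape((b, t, 3), flatten_spec))  # (b * t, 3)

# Un-flattening: b and t must be literals, so the spec is tied to one length.
unflatten_spec = (2, 5, 4)  # built for b = 2, t = 5
print(reshape_shape((2 * 5, 4), unflatten_spec))

try:
    reshape_shape((2 * 7, 4), unflatten_spec)  # same spec, t = 7
except ValueError as err:
    print("shape inference failed:", err)
```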
We wish to do this via InferShape for two reasons. First, the compilation process from Mathematica’s high-level networks to MXNet is expensive, and repeating it for different values of n (or t) is wasteful given that the resulting graphs are otherwise identical. Second, when deploying the net outside of our high-level framework, InferShape is a very easy way to create new graphs for a specific input length from e.g. C or C++ code; it is not possible to call Mathematica’s MXNet compiler in that case.
Does the MXNet community have any suggestions for how to handle this general class of issue? It seems like there is a missing operator, a more flexible version of reshape_like, that would allow us to use a second ‘shape target’ but take only specific dimensions from it to produce the target shape, rather than take all of the dimensions from it.
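To illustrate the kind of operator I mean, here is a rough sketch of its shape rule in plain Python. Everything here is hypothetical: partial_reshape_like and its spec format (("like", axis) to borrow a dimension from the shape target, -1 to infer one dimension, a positive integer for a literal) are invented for this post and are not an existing MXNet API.

```python
from math import prod


def partial_reshape_like(data_shape, like_shape, spec):
    """Output shape of a hypothetical 'take only some dims from like' reshape.

    spec entries:
      ("like", axis) : copy like_shape[axis]
      -1             : infer this dimension from the element count (at most once)
      k > 0          : literal dimension
    """
    out = []
    infer_at = None
    for entry in spec:
        if isinstance(entry, tuple):
            tag, axis = entry
            if tag != "like":
                raise ValueError(f"unknown spec entry: {entry}")
            out.append(like_shape[axis])
        elif entry == -1:
            if infer_at is not None:
                raise ValueError("only one dimension can be inferred")
            infer_at = len(out)
            out.append(1)  # placeholder, filled in below
        elif entry > 0:
            out.append(entry)
        else:
            raise ValueError(f"unknown spec entry: {entry}")

    total = prod(data_shape)
    if infer_at is not None:
        known = prod(out)
        if known == 0 or total % known != 0:
            raise ValueError(f"cannot infer a dimension for {data_shape}")
        out[infer_at] = total // known
    if prod(out) != total:
        raise ValueError(f"cannot reshape {data_shape} to {tuple(out)}")
    return tuple(out)


# Un-flatten Y (b * t, 4) using X (b, t, 3) as the shape target:
# take b and t from X, infer the feature dimension from Y.
spec = [("like", 0), ("like", 1), -1]
for b, t in [(2, 5), (3, 7)]:
    print(partial_reshape_like((b * t, 4), (b, t, 3), spec))  # (b, t, 4)
```

With a rule like this, the spec contains no literal b or t, so the same symbol would remain valid for every sequence length.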
Thanks,
Tali