Say I have a network defined with mxnet.symbol,
and my loss depends on the batch size of the feature map.
For example:
import mxnet as mx
x = get_network()  # x is a symbol
batch_size = get_batch_size()  # THE QUESTION: how do I get this from x?
diag_mask = mx.symbol.eye(batch_size)  # batch_size x batch_size identity mask
loss = get_loss(x, diag_mask)
My dataset size is not divisible by the batch size, so when the program reaches the last batch of the dataset, it fails with a dimension mismatch error.
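To make the failure concrete, here is a plain-Python sketch with hypothetical numbers (10 samples, batch size 4). The helpers batch_sizes and eye are illustrative stand-ins, not MXNet functions: the mask is built once for the nominal batch size, and only the final partial batch (10 % 4 = 2 samples) disagrees with it.

```python
def batch_sizes(dataset_size, batch_size):
    """Sizes of successive batches when the final partial batch is kept."""
    full, rest = divmod(dataset_size, batch_size)
    return [batch_size] * full + ([rest] if rest else [])

def eye(n):
    """Plain-Python n x n identity matrix (stand-in for mx.symbol.eye)."""
    return [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]

# Hypothetical numbers: 10 samples, batch size 4 -> last batch holds 2 samples.
sizes = batch_sizes(10, 4)
fixed_mask = eye(4)  # mask built once for the nominal batch size

for size in sizes:
    status = "ok" if len(fixed_mask) == size else "shape mismatch"
    print(size, status)
```

This prints "ok" for the two full batches and "shape mismatch" for the last one. If changing the data pipeline is acceptable, a common workaround is to avoid partial batches altogether: mx.io.NDArrayIter has a last_batch_handle option ('pad', 'discard', 'roll_over'), and Gluon's DataLoader has a last_batch option ('keep', 'discard', 'rollover'). Check the documentation for your MXNet version.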
How can I get the batch size at run time, like tf.shape(x) in TensorFlow?
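In TensorFlow, tf.shape(x)[0] yields the batch size as a value computed at run time. The sketch below uses NumPy as a stand-in for a symbolic graph (an assumption, not MXNet code) to show the pattern being asked for: derive the mask from the input's first axis on every call instead of from a fixed Python int. The names diag_mask_like and hypothetical_loss are hypothetical, and the toy loss is not the question's get_loss.

```python
import numpy as np

def diag_mask_like(x):
    """(B, B) identity mask whose size follows x's first axis at call time."""
    idx = np.arange(x.shape[0])  # index vector as long as the actual batch
    return (idx[:, None] == idx[None, :]).astype(x.dtype)

def hypothetical_loss(x):
    """Toy stand-in for get_loss: mean diagonal entry of x @ x.T."""
    mask = diag_mask_like(x)
    sim = x @ x.T  # (B, B) similarity matrix
    return float((sim * mask).sum() / x.shape[0])

full_batch = np.ones((4, 3), dtype=np.float32)
last_batch = np.ones((2, 3), dtype=np.float32)
print(diag_mask_like(last_batch).shape)
print(hypothetical_loss(full_batch), hypothetical_loss(last_batch))
```

Because the mask is built by comparing an index vector with itself, it never needs the batch size as a constant. Reproducing this inside an MXNet symbol requires an operator that produces an index vector whose length follows the batch axis; newer MXNet releases appear to provide contrib.arange_like for this, but treat that as an assumption and verify it against your version's documentation.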