Hi all,
I am trying to build a simple network using the mx.symbol.linalg_gemm2
function. However, I cannot get the model to train because the symbol shapes are not inferred correctly. I am using the R API.
The code is as follows:
NFEAT = 795
nDest = 780
batchSz = 150
DL = customArrayIter(X.model, data.shape=c(NFEAT, 150), label=Y.model, batch.size=1)
X = mx.symbol.Variable('data')
A = mx.symbol.Variable('A')
B = mx.symbol.Variable('B')
CC = mx.symbol.linalg_gemm2(X, A, name='CC')
Yhat = mx.symbol.linalg_gemm2(B, CC, name='Yhat')
out = mx.symbol.SoftmaxActivation(Yhat, name='out')
loss = mx.symbol.LinearRegressionOutput(Yhat, name='loss')
Note that when I call mx.symbol.infer.shape directly, as follows:
shps=mx.symbol.infer.shape(loss, data=c(NFEAT, 150), A=c(nDest, NFEAT), B=c(150, 1))
the shapes appear to be inferred correctly.
Can someone advise me on how to specify the shapes in the mx.model.FeedForward.create function?
Thanks a lot!