Please help with this error, the code is below
def net(X):
num_hidden = 256
weight_scale = .01
input_layer = X
num_outputs = 2
W_hat = nd.random_normal(shape=(int(input_layer.shape[-1]), num_outputs), ctx=model_ctx)
M_hat = nd.random_normal(shape=(int(input_layer.shape[-1]), num_outputs), ctx=model_ctx)
G = nd.random_normal( shape=(int(input_layer.shape[-1]), num_outputs), ctx=model_ctx)
params = [W_hat, M_hat, G]
W1 = nd.tanh(W_hat)
W2 = nd.sigmoid(M_hat)
W = W1*W2
a = nd.dot(X,W)
g = nd.dot(X,G)
g = nd.sigmoid(g)
z1 = nd.abs(X)
z = nd.log(z1 + 1e-7)
m = nd.linalg_gemm2(z,W)
y = (g*a) + (1-g)*m
return y
epochs = 10
learning_rate = .001
smoothing_constant = .01
for e in range(epochs):
cumulative_loss = 0
for i, (data, label) in enumerate(train_data):
data = data.as_in_context(model_ctx).reshape((-1, 784))
label = label.as_in_context(model_ctx)
label_one_hot = nd.one_hot(label, 10)
with autograd.record():
output = net(data)
loss = softmax_cross_entropy(output, label_one_hot)
loss.backward()
SGD(params, learning_rate)
cumulative_loss += nd.sum(loss).asscalar()
test_accuracy = evaluate_accuracy(test_data, net)
train_accuracy = evaluate_accuracy(train_data, net)
print("Epoch %s. Loss: %s, Train_acc %s, Test_acc %s" %
(e, cumulative_loss/num_examples, train_accuracy, test_accuracy))
I am getting the following errors, please help with this
MXNetError: [11:05:03] src/operator/tensor/./elemwise_binary_broadcast_op.h:68: Check failed: l == 1 || r == 1 operands could not be broadcast together with shapes [64,10] [64,2]
Stack trace returned 8 entries:
[bt] (0) 0 libmxnet.so 0x00000001077bbeb4 libmxnet.so + 20148
[bt] (1) 1 libmxnet.so 0x00000001077bbc6f libmxnet.so + 19567
[bt] (2) 2 libmxnet.so 0x0000000107cfb0e8 libmxnet.so + 5521640
[bt] (3) 3 libmxnet.so 0x000000010891cc6a MXNDListFree + 505610
[bt] (4) 4 libmxnet.so 0x000000010891b889 MXNDListFree + 500521
[bt] (5) 5 libmxnet.so 0x000000010887942a MXCustomFunctionRecord + 20666
[bt] (6) 6 libmxnet.so 0x000000010887a4d0 MXImperativeInvokeEx + 176
[bt] (7) 7 libffi.6.dylib 0x0000000105883884 ffi_call_unix64 + 76
I did try nd.reshape(output,(64,2)) but did not workout so please help me with this