I try to implement SRU in MXNet, and I find it is difficult to implement a for-loop.
Here is what I have done, but the question is, for i in ndarray
and for i in sym
behaves differently, thus the code not working.
import mxnet as mx
class SRU(mx.gluon.HybridBlock):
def __init__(self,hidden_size,use_bias=True,layout='TNC',activation=None,params=None,prefix=None, **kwargs):
assert layout == 'TNC' , "TNC support only, use swapaxes to change the layout first."
assert activation==None, "You should manually add activaion layers rather than specific it here."
super(SRU, self).__init__(params=params,prefix=prefix,**kwargs)
with self.name_scope():
self.W=mx.sym.var(self.prefix+"_W_weight",shape=(hidden_size,hidden_size),init=mx.init.Xavier())
self.Wfr=mx.sym.var(self.prefix+"_W_fr_weight",shape=(hidden_size,hidden_size*2),init=mx.init.Xavier())
self.v=mx.sym.var(self.prefix+"_v_fr_weight",shape=(1,1,hidden_size*2),init=mx.init.Xavier())#shape is used for broadcast_add which actual shape is (batch_size,hidden_size*2)
self.b=mx.sym.var(self.prefix+"_b_fr_weight",shape=(1,1,hidden_size*2),init=mx.init.Zero())
def gen(self,F,Wx_plus_b_for_fr,Wx,x,c):
for Wxpbt,Wxt,xt in zip(Wx_plus_b_for_fr,Wx,x):
ft,rt=F.sigmoid(F.broadcast_add(Wxpbt,F.broadcast_mul(c.tile(2),self.v))).split(2,axis=0)
c=F.broadcast_add(ft*c,(1-ft)*Wxt)
yield F.broadcast_add(rt*c,(1-rt)*xt)
self.c=c
def hybrid_forward(self,F,x,c,*args,**params):
Wx_plus_b_for_fr = F.broadcast_add(F.dot(x,self.Wfr),self.b)
Wx = F.dot(x,self.W)
ht_f = F.stack(*self.gen(F,Wx_plus_b_for_fr,Wx,x,c))
return ht_f,self.c
I want to know what is the best solution for mxnet to implement a for-loop.
thanks.