Typically ``data_batch.provide_label``.
"""
assert self.binded, 'call bind before switching bucket'
if not bucket_key in self._buckets:
symbol, data_names, label_names = self._sym_gen(bucket_key)
module = Module(symbol, data_names, label_names,
logger=self.logger, context=self._context,
work_load_list=self._work_load_list,
fixed_param_names=self._fixed_param_names,
state_names=self._state_names)
module.bind(data_shapes, label_shapes, self._curr_module.for_training,
self._curr_module.inputs_need_grad,
force_rebind=False, shared_module=self._buckets[self._default_bucket_key])
self._buckets[bucket_key] = module
self._curr_module = self._buckets[bucket_key]
self._curr_bucket_key = bucket_key
def init_optimizer(self, kvstore='local', optimizer='sgd',
optimizer_params=(('learning_rate', 0.01),),
force_init=False):
I want to set the `grad_req` for all my parameters to `'add'` by specifying it in the `bind()` call on my top-level module. The problem is that I use a `BucketingModule`, so when I call `forward`, it automatically calls the `switch_bucket` method, which then calls `bind` again without specifying `grad_req` and therefore sets it to `'write'`.
Is this intended behavior? Can I somehow work around this while still using bucketing?
I already opened an issue on GitHub here: https://github.com/apache/incubator-mxnet/issues/10904
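Probably the simplest solution is to inherit from `BucketingModule` and override `bind()` so that its `grad_req` argument defaults to `'add'`. Note that `switch_bucket` binds the module for each new bucket itself and does not pass `grad_req` along. Buckets other than the one bound through `bind()` may therefore still end up with `'write'`, so check this in your MXNet version.
Something like: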
class MyBucketingModule(BucketingModule):
    """``BucketingModule`` whose ``bind`` defaults ``grad_req`` to ``'add'``.

    NOTE(review): assumes ``BucketingModule`` is in scope at this point
    (e.g. ``from mxnet.module import BucketingModule``) -- confirm.
    """

    def bind(self, data_shapes, label_shapes=None, for_training=True,
             inputs_need_grad=False, force_rebind=False, shared_module=None,
             grad_req='add'):
        """Bind the module, using ``'add'`` as the default ``grad_req``.

        Every argument is forwarded unchanged to ``BucketingModule.bind``;
        the only reason for this override is the different default value
        of ``grad_req``.

        NOTE(review): ``switch_bucket`` creates and binds a ``Module`` for
        each new bucket key itself and does not pass ``grad_req`` to that
        ``bind`` call, so buckets other than the one bound here may not
        inherit ``'add'`` -- verify against the MXNet version in use.
        """
        # Forward by keyword. The previous positional call omitted
        # ``data_shapes``, which shifted every argument one slot to the left
        # (``label_shapes`` was passed as ``data_shapes``, and so on).
        super(MyBucketingModule, self).bind(
            data_shapes,
            label_shapes=label_shapes,
            for_training=for_training,
            inputs_need_grad=inputs_need_grad,
            force_rebind=force_rebind,
            shared_module=shared_module,
            grad_req=grad_req)