Hi, could someone explain the use of mx.symbol.Group?
Does it mean that the symbol will have multiple outputs?
Are these parallel outputs, rather than outputs produced one after another?
Thank you
Hello @xysong1201,
mx.symbol.Group() enables parallel outputs and thus a model with multiple heads.
One exemplary use case is an AlphaZero-like model with a value head and a policy head:
You can also find a visualization of the corresponding NN architecture.
You can do the same in the Gluon API like this:
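def hybrid_forward(self, F, x):
    # shared trunk used by both heads
    out = self.body(x)
    # both heads read the same features within one forward pass
    value = self.value_head(out)
    policy = self.policy_head(out)
    # returning a list gives the block two outputs
    return [value, policy]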
and then define the total loss as a custom linear combination of the two head losses: