Make while_loop work with deferred initialization when the output shape is fixed?

Context: I am using contrib.while_loop in my network to multiply a square matrix by itself repeatedly, with the number of iterations decided dynamically.
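For reference, here is roughly what the loop computes. This is a plain NumPy sketch, not MXNet code. The helper while_loop_reference and the norm-based stopping rule are hypothetical stand-ins: the helper only mimics the cond / func / loop_vars pattern of contrib.while_loop (MXNet's body function also returns per-step outputs, which this sketch drops), and the threshold is just an example of a data-dependent condition. The point is that the loop variable keeps its (n, n) shape on every iteration.

```python
import numpy as np


def while_loop_reference(cond, func, loop_vars, max_iterations):
    """Hypothetical NumPy stand-in for the cond/func/loop_vars pattern of
    contrib.while_loop. Simplification: func returns only the new loop
    variables (no per-step outputs)."""
    steps = 0
    while steps < max_iterations and cond(*loop_vars):
        loop_vars = func(*loop_vars)
        steps += 1
    return loop_vars, steps


def repeated_matmul(a, threshold=1e3, max_iterations=10):
    """Multiply the square matrix `a` by itself until a data-dependent
    condition fails or `max_iterations` is reached."""

    def cond(prod, base):
        # Example stopping rule (assumption): continue while the Frobenius norm is small.
        return np.linalg.norm(prod) < threshold

    def func(prod, base):
        # (n, n) @ (n, n) -> (n, n): the loop variable never changes shape.
        return prod @ base, base

    (result, _), steps = while_loop_reference(cond, func, (a, a), max_iterations)
    return result, steps


result, steps = repeated_matmul(2 * np.eye(3))
print(steps, result.shape)  # prints: 9 (3, 3)
```

However many iterations the condition allows, the result has the same shape as the input matrix.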

Issue: because of this operation, deferred initialization of the layers that come before it (some convolution / dense layers) fails with:

ValueError: Deferred initialization failed because shape cannot be inferred. Cannot decide shape for the following arguments (0s in shape means unknown dimensions). Consider providing them as input:

This makes sense, because in general the output shape of a while loop is unknown. In my case, however, the output shape is known: it is the shape of the input matrix.
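Put differently, the shape rule I would like the shape inference to use is trivial. A small sketch in plain Python/NumPy (the function name loop_output_shape is mine, and allowing a leading batch dimension is an assumption about how the matrix is fed in): each iteration maps (..., n, n) to (..., n, n), so by induction the output has the input's shape whatever the number of iterations.

```python
import numpy as np


def loop_output_shape(input_shape):
    """Hypothetical shape rule for repeatedly multiplying a square matrix
    (or a batch of square matrices) by itself: output shape == input shape."""
    shape = tuple(input_shape)
    if len(shape) < 2 or shape[-1] != shape[-2]:
        raise ValueError(f"expected a shape (..., n, n), got {shape}")
    return shape


# Sanity check: np.matmul keeps (batch, n, n) for any number of iterations.
x = np.random.rand(2, 4, 4)
prod = x
for _ in range(5):
    prod = np.matmul(prod, x)
    assert prod.shape == loop_output_shape(x.shape)
```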

What I tried:

  • Based on this answer, I implemented the operation as a Block instead of a HybridBlock. The drawback is that I then have to write the network sequentially, which is not very convenient in my case (an autoencoder structure).
  • I started looking at HybridBlock's infer_shape method to try to override it, but I got stuck.

Do you have any ideas on how to make this work?

I am using version 1.6.0.