Hi,
I am a bit confused about accessing the weights in a network.
>>> from mxnet.gluon import model_zoo
>>> net = model_zoo.get_model("VGG16", pretrained=True)
>>> all_params = net.collect_params()
all_params is of type ParameterDict.
When I iterate over its keys and values:
>>> for k, v in net.collect_params().items():
...     print(f"{k:>30} - {v.name}")
vgg0_conv0_weight - vgg0_conv0_weight
vgg0_conv0_bias - vgg0_conv0_bias
vgg0_conv1_weight - vgg0_conv1_weight
vgg0_conv1_bias - vgg0_conv1_bias
vgg0_conv2_weight - vgg0_conv2_weight
... etc
I see the keys of the ParameterDict and the name of each Parameter, and they are identical.
However, when I try to access a value by its key, like this:
>>> net.collect_params().get("vgg0_conv0_weight").data()
I get the error:
RuntimeError: Parameter 'vgg0_vgg0_conv0_weight' has not been initialized.
When I remove the prefix vgg0_, it works:
>>> net.collect_params().get("conv0_weight").data()
[[[[-0.03772613 -0.06753712 0.09550434]
[-0.12510349 -0.1030491 0.03194679]
[ 0.04482884 0.02893391 0.07695705]]
... etc
<NDArray 64x3x3x3 @cpu(0)>
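The error message (vgg0_vgg0_conv0_weight) suggests that get() prepends the dictionary's own prefix (vgg0_) to whatever name it is given, while keys() returns the full names. The sketch below models that assumption in plain Python with a hypothetical PrefixedParamDict class; it does not use MXNet, and whether the real ParameterDict also supports full-key lookup via [] indexing is an assumption worth checking against the MXNet docs.

```python
class PrefixedParamDict:
    """Hypothetical toy dict modelling the naming behaviour seen above.

    Not MXNet code: it only mimics a get() that prepends a prefix.
    """

    def __init__(self, prefix):
        self.prefix = prefix
        self._params = {}

    def add(self, full_name, value):
        # Entries are stored under their full names, e.g. "vgg0_conv0_weight".
        self._params[full_name] = value

    def keys(self):
        # keys() returns the full, prefixed names.
        return self._params.keys()

    def __getitem__(self, full_name):
        # Indexing uses the key exactly as returned by keys().
        return self._params[full_name]

    def get(self, name):
        # get() treats its argument as a suffix and prepends the prefix,
        # so "vgg0_conv0_weight" is looked up as "vgg0_vgg0_conv0_weight".
        # In this toy a missing name raises immediately; in the session
        # above the error only appeared when .data() was called.
        full_name = self.prefix + name
        if full_name not in self._params:
            raise RuntimeError(f"Parameter '{full_name}' has not been initialized.")
        return self._params[full_name]


params = PrefixedParamDict("vgg0_")
params.add("vgg0_conv0_weight", "conv0 weights")

key = next(iter(params.keys()))
print(key)                          # vgg0_conv0_weight
print(params[key])                  # conv0 weights
print(params.get("conv0_weight"))   # conv0 weights
try:
    params.get(key)
except RuntimeError as err:
    print(err)  # Parameter 'vgg0_vgg0_conv0_weight' has not been initialized.
```

Under this model, a key from keys() works with [] but must have its prefix stripped before being passed to get().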
In other words, the keys returned by:
>>> net.collect_params().keys()
cannot be used to look up a value with get() without manually removing the prefix first. Am I doing something wrong? Is this intended behaviour or a bug?
Best,
Blake