Python experimentation and Java/Scala production systems

Hi,
I am working on a project where I want to use MXNet for some predictions, but I am very new to the framework and definitely more focused on the ML side than the systems side. In particular, it’s pretty easy for me to experiment with Python/Gluon, but I am kind of stuck trying to understand how to bring my networks into the production system my company is using.
If the production system were in Python, everything would be fairly easy, but it is in Java (Scala is also acceptable), and this creates a lot of confusion. I would like to provide “something” (i.e., some kind of serialized network) that can be retrained with new data over time and used to make predictions, without rewriting it in Java/Scala.
I really want to avoid maintaining the same code in two languages, because that would dramatically slow down my iterations on the model.

I see that serialized models can be read/imported with different language bindings, and my understanding is that using a network to make predictions is trivial if it is pre-trained. In my case it is not: it will need to be trained from scratch, and I didn’t find any information about retraining a network.

My questions are:

  • Given my Python/Gluon code, is there a way to run the same network in Java/Scala? If yes, does it just need to be serialized, or does it require a different procedure?
  • What are the limitations of the different approaches in terms of network architecture?
  • Can I retrain my Python-defined network from Java/Scala code?
  • I have some code doing operations “around” the network. If I serialize the network, I guess this will be lost, so should I make sure all my logic is in the forward method?
  • Is there any material on how to handle MXNet networks in different production environments? (I haven’t found anything satisfactory so far.)

You can export the model from your Python/Gluon code and load the same model in Java. The Java API is only for inference right now, so you won’t be able to retrain the network in Java: https://medium.com/apache-mxnet/introducing-java-apis-for-deep-learning-inference-with-apache-mxnet-8406a698fa5a. Instead, you could retrain it in Scala with the Module API: https://mxnet.incubator.apache.org/versions/master/api/scala/module.html

How do you want to run your model in production? There is a model server for MXNet, an easy-to-use tool for serving deep learning models: https://github.com/awslabs/mxnet-model-server

My plan was to have a simple class that loads data and makes predictions. I’m not sure about the model server, but thanks for the pointer.

Can you also help me understand what gets transferred in the ‘exported’ model and what doesn’t?
Do you have a link/wiki/tutorial/… for this?

Thanks.

Here is a great tutorial: https://mxnet.incubator.apache.org/versions/master/tutorials/gluon/save_load_params.html
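You can either use save_parameters to save only the parameters, or use export to save both the model architecture (as a symbol JSON file) and the parameters. With save_parameters, you need the original Python code that defines the network to rebuild it before loading the parameters.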

I think that the hybridization (?) of the network is what I need to do.

So the plan is: export parameters and architecture from Python/Gluon, load the network in Scala (Java is only for inference), and then train the network in Scala.
Finally, re-serialize the model and start serving predictions from either Java or Scala.

I just need to make sure there is a good way to move some of the code I have around the network into the forward pass, and to somehow use my custom logic for some of the updates, but that should not be an issue.

If you want to store the network architecture itself in a file, then yes, the network has to be hybridizable.
