Thanks for your reply, Vishaal, and thanks also for the interesting link, which gave me a very good idea of what’s eating up memory. Very helpful!
I believe offloading and then prefetching again would have to happen outside the scope of the MXNet APIs, right?
Regarding the specific model, it’s actually a very simple one. Here’s the core of it (it’s a recommender system):
import mxnet as mx

# Inputs: user ids, item ids, and the target score
user = mx.symbol.Variable("user")
item = mx.symbol.Variable("item")
score = mx.symbol.Variable("softmax_label")
# One embedding table per entity type
user_embed = mx.symbol.Embedding(name="user_embed", data=user,
                                 input_dim=max_users, output_dim=embed_size)
item_embed = mx.symbol.Embedding(name="item_embed", data=item,
                                 input_dim=max_items, output_dim=embed_size)
# Normalize, so the dot product below behaves like a cosine similarity
user = mx.symbol.L2Normalization(user_embed)
item = mx.symbol.L2Normalization(item_embed)
# Dot product of the user and item vectors
dot = user * item
dot = mx.symbol.sum(dot, axis=1)
dot = mx.symbol.Flatten(dot)
# Regress the predicted score against the label
pred = mx.symbol.LinearRegressionOutput(data=dot, label=score)
The point of keeping the model this simple is to make the embeddings themselves carry a high “semantic” load. Also, embed_size is fixed at 64 (in my experiments, it can’t go much lower).
The problem is, I believe, not really the topology itself but the number of embeddings: with roughly 7M items and 200k users (max_items and max_users, respectively), the embedding matrices grow very large (the model itself uses about 2.5GB). I could use sparse embeddings, but those would save me less than 10% of the total memory currently in use (and 10% here means I would run into the memory problem again in a few months, as the user/item base grows).
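For reference, here is a back-of-the-envelope estimate of the raw embedding table sizes. It is only a sketch: it assumes float32 parameters and the approximate counts above, the helper name is made up for illustration, and it ignores gradients, optimizer state, and any framework overhead.

```python
def embedding_bytes(num_rows, embed_size, bytes_per_value=4):
    """Size of one dense embedding table (assumes float32 by default)."""
    return num_rows * embed_size * bytes_per_value

# Approximate figures from the post
max_items = 7_000_000
max_users = 200_000
embed_size = 64

item_bytes = embedding_bytes(max_items, embed_size)
user_bytes = embedding_bytes(max_users, embed_size)
total_gb = (item_bytes + user_bytes) / 1e9

print(f"item table: {item_bytes / 1e9:.2f} GB")
print(f"user table: {user_bytes / 1e9:.2f} GB")
print(f"total:      {total_gb:.2f} GB")
```

Under these assumptions the item table accounts for nearly all of the parameter memory, so anything that does not shrink that table is unlikely to help much.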
I’m currently using a fairly large batch size (50k), but some tests showed that reducing it helps only to a very limited extent (again, in the 10% range).
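A rough count of the forward activations suggests why batch size matters so little here. This is a sketch under stated assumptions: float32 values, and five batch-sized arrays kept alive (two embedding lookups, two normalized copies, one elementwise product). The real number depends on how the framework plans memory, and the function name is hypothetical.

```python
def batch_activation_bytes(batch_size, embed_size, bytes_per_value=4,
                           num_arrays=5):
    """Approximate forward-pass memory for one batch.

    num_arrays=5 is an assumption: two looked-up embedding batches,
    two L2-normalized copies, and their elementwise product.
    """
    return num_arrays * batch_size * embed_size * bytes_per_value

activations = batch_activation_bytes(50_000, 64)
# Raw float32 item table with ~7M rows, for comparison
item_table = 7_000_000 * 64 * 4

print(f"activations: {activations / 1e6:.0f} MB")
print(f"share of item table: {activations / item_table:.1%}")
```

Even at 50k samples per batch, the activations in this estimate are a small fraction of the item embedding table alone.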
I have been exploring mixed precision training (following https://mxnet.incubator.apache.org/faq/float16.html), but so far no luck… In fact, if I pass the multi_precision=True flag to the optimizer, it actually seems to use more memory (!). I’m still exploring this, however.
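One possible explanation, sketched below as a per-parameter byte count: with multi_precision=True the optimizer keeps an internal float32 copy of each weight next to the float16 one. The accounting is an assumption-laden simplification (one weight, one gradient, an optional float32 master copy, and a number of optimizer state values such as momentum), and the function is purely illustrative, not an MXNet API.

```python
def bytes_per_parameter(weight_bytes, master_copy, state_slots):
    """Approximate persistent bytes stored per model parameter.

    weight_bytes: 4 for float32 weights, 2 for float16
    master_copy:  True if a float32 master copy of the weight is kept
    state_slots:  optimizer state values per parameter (e.g. 1 for momentum)
    """
    total = 2 * weight_bytes          # weight + gradient
    if master_copy:
        total += 4                    # float32 master copy of the weight
    # Assumption: state is float32 whenever updates are applied in float32
    state_bytes = 4 if (master_copy or weight_bytes == 4) else 2
    total += state_slots * state_bytes
    return total

fp32 = bytes_per_parameter(4, master_copy=False, state_slots=1)
fp16 = bytes_per_parameter(2, master_copy=False, state_slots=1)
fp16_multi = bytes_per_parameter(2, master_copy=True, state_slots=1)
print(fp32, fp16, fp16_multi)
```

By this count, float16 with a float32 master copy costs about as much per parameter as plain float32, and clearly more than plain float16. That would fit a model like this one, where the parameters rather than the activations dominate memory use.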
Thanks for your message!