Why does NDArrayIter use Float32 for the label DataDesc?

I had trouble getting a simple regression model to work because NDArrayIter was passing the data type as Float64 but the label type as Float32. (The values are, in fact, Float64, as they come from a PostgreSQL database.) This mismatch caused problems downstream when trying to run fit() with LinearRegressionOutput as the last layer.

I traced the problem down to this bit of Scala code, where NDArrayIter is constructing what will become the DataDesc objects for use in provideDataDesc and provideLabelDesc.

------ in NDArrayIter -------

def this(data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = IndexedSeq.empty,
         dataBatchSize: Int = 1, shuffle: Boolean = false,
         lastBatchHandle: String = "pad",
         dataName: String = "data", labelName: String = "label") {
  this(IO.initDataDesc(data, allowEmpty = false, dataName,
         if (data == null || data.isEmpty) MX_REAL_TYPE else data(0).dtype, Layout.UNDEFINED),
       IO.initDataDesc(label, allowEmpty = true, labelName, MX_REAL_TYPE, Layout.UNDEFINED),
       dataBatchSize, shuffle, lastBatchHandle)
}

As you can see, it respects the actual type of the data (if provided), but never respects the type of the label (if provided). Instead, it ALWAYS treats the labels as MX_REAL_TYPE (i.e., Float32), regardless of what they really are.

Shouldn’t the labels be handled in the same way as the data? That is, respect the type if provided, and fall back to MX_REAL_TYPE only if not provided?
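
To make the question concrete, here is a minimal sketch of the fallback I have in mind. It deliberately does not use MXNet: DType, FakeNDArray, LabelDType and inferDType are hypothetical stand-ins I made up for illustration, and DefaultRealType only mirrors the role of MX_REAL_TYPE. It is not the actual NDArrayIter code, just the same "use the first array's dtype, else the default" rule applied to any sequence of arrays.

```scala
// Stand-in for MXNet's dtype values; this sketch does not depend on the real library.
object DType extends Enumeration {
  val Float32, Float64 = Value
}

// Hypothetical minimal array type that only carries a dtype.
final case class FakeNDArray(dtype: DType.Value)

object LabelDType {
  // Plays the role of MX_REAL_TYPE (Float32) in this sketch.
  val DefaultRealType: DType.Value = DType.Float32

  // Use the dtype of the first array if one is provided, otherwise the default.
  def inferDType(arrays: IndexedSeq[FakeNDArray]): DType.Value =
    if (arrays == null || arrays.isEmpty) DefaultRealType else arrays(0).dtype
}
```

With a rule like this applied to both arguments, Float64 labels would be described as Float64, and an empty label sequence would still be described as Float32.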

This was a change in PR #13678 by @piyushghai. I can’t tell why the same logic hasn’t been added for the label. Maybe Piyush can answer.

Hi,

That’s a good catch. I’d thought about it and had not anticipated a use case for the labels to be of precision type Double. Generally they are discrete values/class labels. :slight_smile:

Nevertheless, I will raise a PR to respect the label’s datatype and that should help unblock you :slight_smile:

I’ll post a link to the PR here when I raise it.

That’s great, Piyush. I’m downcasting to Float32 for both data and labels for now (and that works in my case), but it does seem like something that should be changed.
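
For anyone hitting the same thing, the workaround amounts to roughly the following. This is only a sketch in plain Scala: Downcast and toFloat32 are hypothetical names, and it assumes the values arrive as Array[Double] from the database and are converted before any NDArray is built (the NDArray construction itself is left out).

```scala
object Downcast {
  // Convert Double values (e.g. read from a database) to Float,
  // so that data and labels end up with the same 32-bit type.
  def toFloat32(values: Array[Double]): Array[Float] =
    values.map(_.toFloat)
}
```

The conversion loses precision beyond what Float can represent, which is acceptable in my case but may not be in others.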

Thanks,

Kenner


Here we go. https://github.com/apache/incubator-mxnet/pull/14038
This PR should help unblock you :slight_smile:
