Resnet does not want to float16

The “model” below is Resnet34_V2 from the Gluon model zoo. I’d like to train it on AWS P3 instances with mixed precision. Any idea why it returns the following error?

mxnet.base.MXNetError: Error in operator resnetv20_conv0_fwd: [17:52:13] src/operator/nn/convolution.cc:283: Check failed: (*in_type)[i] == dtype (2 vs. 0) This layer requires uniform type. Expected 'float32' v.s. given 'float16' at 'weight'

    net = models.get_model(model, ctx=ctx, pretrained=False, classes=10)
    net.cast('float16')
    net.initialize(init=mx.init.Xavier(magnitude=2), ctx=ctx)
    net.hybridize(static_alloc=True, static_shape=True)


    # Trainer applies optimizer to a set of parameters
    trainer = gluon.Trainer(
        params=net.collect_params(),
        optimizer='sgd',
        optimizer_params={'learning_rate': learning_rate,
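                          # multi_precision keeps a float32 master copy of the
                          # float16 weights so the SGD update runs in float32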
                          'multi_precision': True,
                          'momentum': momentum, 'wd': wd},
        kvstore=kvstore)

    metric = mx.metric.Accuracy()
    loss = gluon.loss.SoftmaxCrossEntropyLoss()

    best_accuracy = 0.0

    for epoch in range(epochs):
        # reset data iterator and metric at beginning of epoch.
        train_data.reset()
        tic = time.time()
        metric.reset()
        btic = time.time()

        trainer.set_learning_rate(
            scheduler(
                epoch=epoch,
                origin_mul=float(lr_schedule[0]),
                peak_mul=float(lr_schedule[1]),
                peak_epoch=float(lr_schedule[2]),
                max_epoch=float(lr_schedule[3]),
                min_value=float(lr_schedule[4])))

        for i, batch in enumerate(train_data):
            data = gluon.utils.split_and_load(data=batch.data[0].astype('float16'), ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(data=batch.label[0].astype('float16'), ctx_list=ctx, batch_axis=0)
            outputs = []
            Ls = []
            with ag.record():
                for x, y in zip(data, label):
                    z = net(x)
                    L = loss(z, y)
                    # store the loss and do backward after we have done forward
                    # on all GPUs for better speed on multiple GPUs.
                    Ls.append(L)
                    outputs.append(z)
                for L in Ls:
                    L.backward()
            trainer.step(batch.data[0].shape[0])
            metric.update(label, outputs)
            if i % log_interval == 0 and i > 0:
                name, acc = metric.get()
                logging.info('Epoch [%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f' %
                             (epoch, i, batch_size / (time.time() - btic), name, acc))
            btic = time.time()

        name, acc = metric.get()
        logging.info('[Epoch %d] training: %s=%f' % (epoch, name, acc))
        logging.info('[Epoch %d] time cost: %f' % (epoch, time.time() - tic))

Hi @olivcruche. I can’t reproduce your problem with the following code snippet. Do you see any issues when you run this code:

import mxnet as mx
from mxnet import autograd, gluon, nd

ctx = [mx.gpu(i) for i in range(8)]
learning_rate = 0.01
momentum = 0.99
wd = 0.0001
kvstore = 'device'

net = gluon.model_zoo.vision.get_model('resnet34_v2', ctx=ctx, pretrained=False, classes=10)

net.cast('float16')
net.initialize(init=mx.init.Xavier(magnitude=2), ctx=ctx)
net.hybridize(static_alloc=True, static_shape=True)

# Trainer applies optimizer to a set of parameters
trainer = gluon.Trainer(
    params=net.collect_params(),
    optimizer='sgd',
    optimizer_params={'learning_rate': learning_rate,
                      'multi_precision': True,
                      'momentum': momentum, 'wd': wd},
    kvstore=kvstore)
metric = mx.metric.Accuracy()
loss = gluon.loss.SoftmaxCrossEntropyLoss()

batch_data = nd.random.uniform(shape=(128, 3, 128, 128))
batch_label = nd.random.uniform(shape=(128,))

data = gluon.utils.split_and_load(data=batch_data.astype('float16'), ctx_list=ctx, batch_axis=0)
label = gluon.utils.split_and_load(data=batch_label.astype('float16'), ctx_list=ctx, batch_axis=0)
outputs = []
Ls = []
with autograd.record():
    for x, y in zip(data, label):
        z = net(x)
        L = loss(z, y)
        # store the loss and do backward after we have done forward
        # on all GPUs for better speed on multiple GPUs.
        Ls.append(L)
        outputs.append(z)
    for L in Ls:
        L.backward()

trainer.step(batch_size=128)

The code below reproduces the error (on a SageMaker p3.8xl, conda_mxnet36 kernel).
To get the data, run the snippet below:

from cifar10_utils import download_training_data
download_training_data()

cifar10_utils is available here: https://github.com/awslabs/amazon-sagemaker-examples/tree/master/sagemaker-python-sdk/mxnet_gluon_cifar10

import ast
import json
import logging
import math
import os
import time

import mxnet as mx
from mxnet import autograd as ag
from mxnet import gluon
from mxnet.gluon.model_zoo import vision as models


logging.basicConfig(level=logging.DEBUG)
# derived from https://github.com/awslabs/amazon-sagemaker-examples/blob/master/sagemaker-python-sdk/mxnet_gluon_cifar10/cifar10.py


def get_data(path, augment, num_cpus, batch_size, data_shape, resize=-1, num_parts=1, part_index=0):
    
    return mx.io.ImageRecordIter(
        path_imgrec=path,
        resize=resize,
        data_shape=data_shape,
        batch_size=batch_size,
        rand_crop=augment,
        rand_mirror=augment,
        preprocess_threads=num_cpus,
        num_parts=num_parts,
        part_index=part_index)


def get_test_data(num_cpus, data_dir, batch_size, data_shape, resize=-1):
    
    return get_data(
        path=os.path.join(data_dir, "test.rec"),
        augment=False,
        num_cpus=num_cpus,
        batch_size=batch_size,
        data_shape=data_shape,
        resize=resize,
        num_parts=1,
        part_index=0)


def get_train_data(num_cpus, data_dir, batch_size, data_shape, resize=-1, num_parts=1, part_index=0):
    
    return get_data(
        path=os.path.join(data_dir, "train.rec"),
        augment=True,
        num_cpus=num_cpus,
        batch_size=batch_size,
        data_shape=data_shape,
        resize=resize,
        num_parts=num_parts,
        part_index=part_index)


def test(ctx, net, test_data):
    
    test_data.reset()
    metric = mx.metric.Accuracy()

    for i, batch in enumerate(test_data):
        data = gluon.utils.split_and_load(data=batch.data[0], ctx_list=ctx, batch_axis=0)
        label = gluon.utils.split_and_load(data=batch.label[0], ctx_list=ctx, batch_axis=0)
        outputs = []
        for x in data:
            outputs.append(net(x))
        metric.update(label, outputs)
    return metric.get()


# we can, for example, implement a custom scheduler:
def scheduler(epoch, origin_mul, peak_mul, peak_epoch, max_epoch, min_value):
    """returns a multiplier
        - origin_mul: multiplier for epoch 0
        - peak_mul: multiplier for the peak epoch
        - peak_epoch: when to peak the learning rate multiplier
        - max_epoch: when the multiplier will be equal to "min_value"
        - min_value: min value of the multiplier
    """
    # linear scaling until peak_mul
    if epoch <= peak_epoch:
        return ((peak_mul - origin_mul)/ float(peak_epoch)) * epoch + origin_mul
    else:
        return peak_mul * math.exp((math.log(min_value/peak_mul)/(max_epoch - peak_epoch)) 
                                   * (epoch - peak_epoch))
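
# Example (illustrative values): scheduler(epoch=0, origin_mul=0.1, peak_mul=1.0,
# peak_epoch=5.0, max_epoch=110.0, min_value=0.01) returns 0.1; epoch=5 returns
# 1.0 (the peak); epoch=110 returns 0.01 (the floor).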


def plateau_scheduler(epoch, dict_schedule):
    """dict schedule: {'epoch': lrm, 'epochN': lrmN}
    """
    keys = sorted([int(k) for k in dict_schedule.keys()])
    for k in keys:
        if epoch >= k:
            schedule = dict_schedule[str(k)]
            
    return schedule
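
# Example: with dict_schedule = {'0': 0.1, '5': 1, '80': 0.1, '110': 0.01}
# (the schedule used below), epochs 0-4 map to 0.1, 5-79 to 1, 80-109 to 0.1,
# and 110+ to 0.01.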




if __name__ =='__main__':

    # GET HYPERPARAMETERS ------------------------------------------------------
    # algo
    batch_size = 128  # int(os.environ['SM_HP_BATCH_SIZE'])
    epochs = 150  # int(os.environ['SM_HP_EPOCHS'])
    learning_rate = 0.1  # float(os.environ['SM_HP_LEARNING_RATE'])
    momentum = 0.88  # float(os.environ['SM_HP_MOMENTUM'])
    log_interval = 1  # int(os.environ['SM_HP_LOG_INTERVAL'])
    wd = 0.0005  # float(os.environ['SM_HP_WD'])
    model = 'resnet18_v2'  # os.environ['SM_HP_MODEL']
    scheduler_name = 'plateau'  # os.environ['SM_HP_SCHEDULER']; renamed so it does not shadow the scheduler() function above
    lr_schedule = ast.literal_eval("{'0': 0.1, '5': 1, '80': 0.1, '110': 0.01}")  # ast.literal_eval(os.environ['SM_HP_SCHEDULE'])
    

    # CONFIGURE ARCHITECTURE AND ADAPT TO DISTRIB CONTEXT ----------------------
    # know your place in the cluster
    #hosts = ast.literal_eval(os.environ['SM_HOSTS'])
    #current_host = os.environ['SM_CURRENT_HOST']
 

    ## define local hardware context
    #if len(hosts) == 1:
    kvstore = 'device'
    #    # set num workers
    #    os.environ['DMLC_NUM_WORKER'] = '1'
    #else:
    #    kvstore = 'dist_device_sync'
    #    # set num workers
    #    os.environ['DMLC_NUM_WORKER'] = str(len(hosts))
    #    os.environ['DMLC_NUM_SERVER'] = str(len(hosts))

    #logging.info('[Launch sequence] training on ' + str(len(hosts)) + ' machines')        

    ctx = [mx.gpu(i) for i in mx.test_utils.list_gpus()]

    batch_size = batch_size * max(1, len(ctx))

    logging.info('[Launch sequence] local context is ' + str(len(ctx)) + ' GPUs. '
          + ' Batch size increased to ' + str(batch_size))

    
    # BRING DATA ----------------------------------------------------------------
    # create part_indexes in case we're in multi-node architecture
    #part_index = 0
    #for i, host in enumerate(hosts):
    #    if host == current_host:
    #        part_index = i
    #        break

    train_data_dir = 'data'  # os.environ['SM_CHANNEL_TRAINING']
    val_data_dir = 'data'  # os.environ['SM_CHANNEL_VALIDATION']

    train_data = get_train_data(num_cpus=8, data_dir=train_data_dir, batch_size=batch_size,
        data_shape=(3, 32, 32), resize=-1)

    test_data = get_test_data(num_cpus=8, data_dir=val_data_dir, batch_size=batch_size,
        data_shape=(3, 32, 32), resize=-1)


    # INITIALIZE AND LAUNCH TRAINING --------------------------------------------
    # Collect all parameters from net and its children, then initialize them.
    net = models.get_model(model, ctx=ctx, pretrained=False, classes=10)
    net.cast('float16')
    net.initialize(init=mx.init.Xavier(magnitude=2), ctx=ctx)
    net.hybridize(static_alloc=True, static_shape=True)


    # Trainer applies optimizer to a set of parameters
    trainer = gluon.Trainer(
        params=net.collect_params(),
        optimizer='sgd',
        optimizer_params={'learning_rate': learning_rate,
                          'multi_precision': True,
                          'momentum': momentum, 'wd': wd},
        kvstore=kvstore)

    metric = mx.metric.Accuracy()
    loss = gluon.loss.SoftmaxCrossEntropyLoss()

    best_accuracy = 0.0


    for epoch in range(epochs):
        # reset data iterator and metric at beginning of epoch.
        train_data.reset()
        tic = time.time()
        metric.reset()
        btic = time.time()

        if scheduler_name == 'super':
            lr = learning_rate * scheduler(
                epoch=epoch,
                origin_mul=float(lr_schedule[0]),
                peak_mul=float(lr_schedule[1]),
                peak_epoch=float(lr_schedule[2]),
                max_epoch=float(lr_schedule[3]),
                min_value=float(lr_schedule[4]))
        
        elif scheduler_name == 'plateau':
            lr = learning_rate * plateau_scheduler(epoch=epoch, dict_schedule=lr_schedule)

        print('LEARNING RATE SET TO ' + str(lr))
        trainer.set_learning_rate(lr)  # lr already includes the base learning_rate
            

        for i, batch in enumerate(train_data):
            data = gluon.utils.split_and_load(data=batch.data[0].astype('float16'), ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(data=batch.label[0].astype('float16'), ctx_list=ctx, batch_axis=0)
            outputs = []
            Ls = []
            with ag.record():
                for x, y in zip(data, label):
                    z = net(x)
                    L = loss(z, y)
                    # store the loss and do backward after we have done forward
                    # on all GPUs for better speed on multiple GPUs.
                    Ls.append(L)
                    outputs.append(z)
                for L in Ls:
                    L.backward()
            trainer.step(batch.data[0].shape[0])
            metric.update(label, outputs)
            if i % log_interval == 0 and i > 0:
                name, acc = metric.get()
                logging.info('Epoch [%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f' %
                             (epoch, i, batch_size / (time.time() - btic), name, acc))
            btic = time.time()

        name, acc = metric.get()
        logging.info('[Epoch %d] training: %s=%f' % (epoch, name, acc))
        logging.info('[Epoch %d] time cost: %f' % (epoch, time.time() - tic))
        
        if epoch % 10 == 0 or epoch in [0, epochs - 1]:
            name, val_acc = test(ctx, net, test_data)
            logging.info('[Epoch %d] validation: %s=%f' % (epoch, name, val_acc))

Full error log:

MXNetError: Error in operator resnetv23_conv0_fwd: [11:39:17] src/operator/nn/convolution.cc:283: Check failed: (*in_type)[i] == dtype (2 vs. 0) This layer requires uniform type. Expected 'float32' v.s. given 'float16' at 'weight'

Stack trace returned 10 entries:
[bt] (0) /home/ec2-user/anaconda3/envs/mxnet_p36/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x36161a) [0x7f009fd3a61a]
[bt] (1) /home/ec2-user/anaconda3/envs/mxnet_p36/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x361c31) [0x7f009fd3ac31]
[bt] (2) /home/ec2-user/anaconda3/envs/mxnet_p36/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x62c15c) [0x7f00a000515c]
[bt] (3) /home/ec2-user/anaconda3/envs/mxnet_p36/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x2aa1d7e) [0x7f00a247ad7e]
[bt] (4) /home/ec2-user/anaconda3/envs/mxnet_p36/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x2aab177) [0x7f00a2484177]
[bt] (5) /home/ec2-user/anaconda3/envs/mxnet_p36/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x2aabc3a) [0x7f00a2484c3a]
[bt] (6) /home/ec2-user/anaconda3/envs/mxnet_p36/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x2ac6427) [0x7f00a249f427]
[bt] (7) /home/ec2-user/anaconda3/envs/mxnet_p36/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x2acab75) [0x7f00a24a3b75]
[bt] (8) /home/ec2-user/anaconda3/envs/mxnet_p36/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x2ad2364) [0x7f00a24ab364]
[bt] (9) /home/ec2-user/anaconda3/envs/mxnet_p36/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x2ad3951) [0x7f00a24ac951]

OMG, how is that possible - I made the same mistake as in https://discuss.mxnet.apache.org/t/simple-float16-example-not-working/2479…: I did not cast my eval function’s inputs to float16.

Adding a “precision” parameter to the eval function fixed it:

def test(ctx, net, test_data, precision='float32'):
    
    test_data.reset()
    metric = mx.metric.Accuracy()

    for i, batch in enumerate(test_data):
        data = gluon.utils.split_and_load(data=batch.data[0].astype(precision), ctx_list=ctx, batch_axis=0)
        label = gluon.utils.split_and_load(data=batch.label[0].astype(precision), ctx_list=ctx, batch_axis=0)
        outputs = []
        for x in data:
            outputs.append(net(x))
        metric.update(label, outputs)
    return metric.get()
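
For completeness, since the network is cast to float16, the validation call in the training loop then becomes (assuming the same ctx, net and test_data as above):

name, val_acc = test(ctx, net, test_data, precision='float16')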