Modifying tutorial chapter 14 for multi-GPU training

Hi.
I am trying to modify the gluon.ai chapter 14 pix2pix GAN to run on multiple GPUs, following the instructions from chapter 7. But my code is unstable and produces visually poor output. The official (single-GPU) code is in gluon.ai chapter 14, pix2pix.
Here is my modified code, the error, and a picture.

m_ctx contains 4 GPUs
ctx contains 1 GPU
batch size with 4 GPUs = 40
batch size with 1 GPU = 10

def train():
    """Train the pix2pix generator/discriminator pair on every device in ``m_ctx``.

    Each batch is split across the devices. The discriminator is updated
    first, then the generator. Progress is logged every ten batches and one
    generated image is displayed at the end of every epoch.
    """
    # One history pool per device. A single shared pool ends up holding arrays
    # from several devices, and feeding such an array to the discriminator on
    # a different device fails at runtime.
    image_pools = {c: ImagePool(pool_size) for c in m_ctx}
    metric = mx.metric.CustomMetric(facc)

    logging.basicConfig(level=logging.DEBUG)
    netG.collect_params().reset_ctx(m_ctx)
    netD.collect_params().reset_ctx(m_ctx)
    trainerG = gluon.Trainer(netG.collect_params(), 'adam', {'learning_rate': 0.0002, 'beta1': 0.5})
    trainerD = gluon.Trainer(netD.collect_params(), 'adam', {'learning_rate': 0.0002, 'beta1': 0.5})
    for epoch in range(epochs):
        tic = time.time()
        btic = time.time()
        train_data.reset()
        for batch_idx, batch in enumerate(train_data):
            real_in = gluon.utils.split_and_load(batch.data[0], ctx_list=m_ctx)
            real_out = gluon.utils.split_and_load(batch.data[1], ctx_list=m_ctx)
            ############################
            # (1) Update D network: maximize log(D(x, y)) + log(1 - D(x, G(x, z)))
            ###########################
            # Generated outside autograd.record(): this pass only feeds D.
            fake_out = [netG(X) for X in real_in]
            fake_concat = [image_pools[X.context].query(nd.concat(X, Y, dim=1))
                           for X, Y in zip(real_in, fake_out)]

            with autograd.record():
                # Fake pairs should be classified as 0. Labels are built per
                # output so the code works for any number of devices.
                output = [netD(X) for X in fake_concat]
                fake_label = [nd.zeros(o.shape, ctx=o.context) for o in output]
                errD_fake = [GAN_loss(X, Y) for X, Y in zip(output, fake_label)]
                metric.update(fake_label, output)

                # Real pairs should be classified as 1.
                real_concat = [nd.concat(X, Y, dim=1) for X, Y in zip(real_in, real_out)]
                output = [netD(X) for X in real_concat]
                real_label = [nd.ones(o.shape, ctx=o.context) for o in output]
                errD_real = [GAN_loss(X, Y) for X, Y in zip(output, real_label)]
                errD = [((X + Y) * 0.5) for X, Y in zip(errD_real, errD_fake)]

                autograd.backward(errD)
                metric.update(real_label, output)
            trainerD.step(batch.data[0].shape[0])

            ############################
            # (2) Update G network: maximize log(D(x, G(x, z))) - lambda1 * L1(y, G(x, z))
            ###########################
            with autograd.record():
                fake_out = [netG(X) for X in real_in]
                fake_concat = [nd.concat(X, Y, dim=1) for X, Y in zip(real_in, fake_out)]
                output = [netD(X) for X in fake_concat]
                real_label = [nd.ones(o.shape, ctx=o.context) for o in output]
                # NOTE(review): the L1 term is weighted by 1000 here - confirm
                # this is the intended lambda1.
                errG = [(GAN_loss(A, B) + L1_loss(C, D) * 1000)
                        for A, B, C, D in zip(output, real_label, real_out, fake_out)]
                autograd.backward(errG)
            trainerG.step(batch.data[0].shape[0])

            # Print log information every ten batches (losses of the first device only)
            if batch_idx % 10 == 0:
                name, acc = metric.get()
                logging.info('speed: {} samples/s'.format(batch_size / (time.time() - btic)))
                logging.info(
                    'discriminator loss = %f, generator loss = %f, binary training acc = %f at iter %d epoch %d'
                    % (nd.mean(errD[0]).asscalar(),
                       nd.mean(errG[0]).asscalar(), acc, batch_idx, epoch))
            btic = time.time()

        name, acc = metric.get()
        metric.reset()
        logging.info('\nbinary training acc at epoch %d: %s=%f' % (epoch, name, acc))
        logging.info('time: %f' % (time.time() - tic))

        # Show the first generated image from the first device's last batch.
        fake_img = fake_out[0]
        visualize(fake_img[0])
        plt.show()

Error (occurs only sometimes):
INFO:root:
binary training acc at epoch 22: facc=0.999336
INFO:root:time: 7.901645
Traceback (most recent call last):
File "/root/PycharmProjects/untitled/ssd_denoising/denoiser_mult.py", line 415, in <module>
train()
File "/root/PycharmProjects/untitled/ssd_denoising/denoiser_mult.py", line 329, in train
metric.update([x for x in fake_label], [x for x in output])
File "/root/anaconda3/lib/python3.6/site-packages/mxnet/metric.py", line 1376, in update
pred = pred.asnumpy()
File "/root/anaconda3/lib/python3.6/site-packages/mxnet/ndarray/ndarray.py", line 1972, in asnumpy
ctypes.c_size_t(data.size)))
File "/root/anaconda3/lib/python3.6/site-packages/mxnet/base.py", line 252, in check_call
raise MXNetError(py_str(_LIB.MXGetLastError()))
mxnet.base.MXNetError: [03:41:31] src/operator/tensor/./../../common/../operator/mxnet_op.h:622: Check failed: err == cudaSuccess (33 vs. 0) Name: mxnet_generic_kernel ErrStr:invalid resource handle

Stack trace returned 10 entries:
[bt] (0) /root/anaconda3/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x382eea) [0x7f6d4dcdbeea]
[bt] (1) /root/anaconda3/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x383521) [0x7f6d4dcdc521]
[bt] (2) /root/anaconda3/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x44bcd35) [0x7f6d51e15d35]
[bt] (3) /root/anaconda3/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x44bddc6) [0x7f6d51e16dc6]
[bt] (4) /root/anaconda3/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x44c552b) [0x7f6d51e1e52b]
[bt] (5) /root/anaconda3/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x2b88698) [0x7f6d504e1698]
[bt] (6) /root/anaconda3/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x2ae9137) [0x7f6d50442137]
[bt] (7) /root/anaconda3/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x2ae9137) [0x7f6d50442137]
[bt] (8) /root/anaconda3/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x2ae9189) [0x7f6d50442189]
[bt] (9) /root/anaconda3/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x2af2bc4) [0x7f6d5044bbc4]

My output and the official output, both after 100 epochs:
myplot1
My output is shown here; the official output is in post #2.

I have tried to write a multi-GPU version of this code myself, and found that the ImagePool class was the culprit behind the wrong behaviour in my case. It is used as a cache, and in the multi-GPU case it ends up containing images from different contexts. This eventually raises an exception once an image from one context is passed to an execution process of another context.

I simply created a dictionary of ImagePool instances keyed by context. You can probably do it more elegantly, but it works this way as well. So, the first line of the train method looks like

image_pools = {c: ImagePool(pool_size) for c in ctx}

And here is the full example for 4 gpu:

from __future__ import print_function
import os
import matplotlib as mpl
import tarfile
import matplotlib.image as mpimg
from matplotlib import pyplot as plt

import mxnet as mx
from mxnet import autograd, gluon
from mxnet import ndarray as nd
from mxnet.gluon import utils
import numpy as np

from ImagePool import ImagePool
from model import UnetGenerator, Discriminator
from datetime import datetime
import time
import logging

# Number of passes over the training set.
epochs = 100

# Devices to train on: four GPUs when enabled, otherwise the CPU.
use_gpu = True
ctx = [mx.gpu(i) for i in range(4)] if use_gpu else [mx.cpu()]
# Global batch size and learning rate both scale with the device count.
batch_size = len(ctx) * 10
lr = len(ctx) * 0.0002
# Adam beta1 passed to both trainers.
beta1 = 0.5
# Weight of the L1 term in the generator loss.
lambda1 = 100

# Capacity handed to each ImagePool.
pool_size = 50

# Dataset name; also the directory the images are read from.
dataset = 'facades'

# Width and height of each half (input / output) of a training image.
img_wd = 256
img_ht = 256
train_img_path = '{}/train'.format(dataset)
val_img_path = '{}/val'.format(dataset)


def download_data(dataset):
    """Download and unpack the pix2pix dataset named ``dataset`` if it is missing.

    Nothing happens when a path called ``dataset`` already exists. Otherwise
    the archive is downloaded, extracted into the current directory, and the
    downloaded archive is deleted.

    Args:
        dataset: Dataset name, used both in the download URL and as the
            directory whose presence marks the data as available.
    """
    if os.path.exists(dataset):
        return

    url = 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/%s.tar.gz' % (dataset)
    # Download before creating the directory: creating it first would make a
    # failed download look like a complete dataset on the next run.
    data_file = utils.download(url)
    try:
        with tarfile.open(data_file) as tar:
            # NOTE(review): extractall trusts the member paths in the archive;
            # only use it with a trusted download source.
            tar.extractall(path='.')
    finally:
        # Remove the archive whether or not extraction succeeded.
        os.remove(data_file)

    # Keep the guarantee that the dataset directory exists afterwards.
    if not os.path.isdir(dataset):
        os.mkdir(dataset)


def load_data(path, batch_size, is_reversed=False):
    """Build an ``NDArrayIter`` of (input, output) image pairs found under ``path``.

    Every ``.jpg`` file below ``path`` is read, resized to
    ``img_wd * 2`` by ``img_ht`` and cut into a left and a right half.

    Args:
        path: Directory that is walked recursively for ``.jpg`` files.
        batch_size: Batch size handed to the iterator.
        is_reversed: When true the right half is the input and the left half
            the output; otherwise the other way round.

    Returns:
        An ``mx.io.NDArrayIter`` whose data is ``[inputs, outputs]``.
    """
    inputs = []
    outputs = []
    for root, _, fnames in os.walk(path):
        for fname in (f for f in fnames if f.endswith('.jpg')):
            img_file = os.path.join(root, fname)
            # Maps a 0..255 pixel value onto -1..1.
            pixels = mx.image.imread(img_file).astype(np.float32) / 127.5 - 1
            pixels = mx.image.imresize(pixels, img_wd * 2, img_ht)

            # Split into the left and right halves.
            left = mx.image.fixed_crop(pixels, 0, 0, img_wd, img_ht)
            right = mx.image.fixed_crop(pixels, img_wd, 0, img_wd, img_ht)

            # Move the last axis to the front, then prepend an axis of size 1
            # so the samples can be concatenated along axis 0.
            left = nd.transpose(left, (2, 0, 1))
            right = nd.transpose(right, (2, 0, 1))
            left = left.reshape((1,) + left.shape)
            right = right.reshape((1,) + right.shape)

            if is_reversed:
                source, target = right, left
            else:
                source, target = left, right
            inputs.append(source)
            outputs.append(target)

    stacked = [nd.concat(*inputs, dim=0), nd.concat(*outputs, dim=0)]
    return mx.io.NDArrayIter(data=stacked, batch_size=batch_size)

# Fetch the dataset on first run, then build the training and validation iterators.
# is_reversed=True makes the right half of each image the network input.
download_data(dataset)
train_data = load_data(train_img_path, batch_size, is_reversed=True)
val_data = load_data(val_img_path, batch_size, is_reversed=True)


def visualize(img_arr):
    """Draw ``img_arr`` on the current matplotlib axes with the axis hidden.

    The array is converted to numpy, its first axis is moved to the end, and
    values are mapped from -1..1 to 0..255 before being cast to ``uint8``.
    """
    channels_last = img_arr.asnumpy().transpose(1, 2, 0)
    pixels = ((channels_last + 1.0) * 127.5).astype(np.uint8)
    plt.imshow(pixels)
    plt.axis('off')


def preview_train_data():
    """Show the first four training pairs: inputs on the top row, outputs below.

    Consumes one batch from the module-level ``train_data`` iterator.
    """
    inputs, outputs = train_data.next().data
    for column in range(1, 5):
        # Top row holds slots 1-4 of the 2x4 grid, bottom row slots 5-8.
        plt.subplot(2, 4, column)
        visualize(inputs[column - 1])
        plt.subplot(2, 4, column + 4)
        visualize(outputs[column - 1])
    plt.show()


def param_init(param):
    """Initialise one parameter on every device in ``ctx``, chosen by its name.

    * names containing ``conv``: weights from ``Normal(0.02)``, everything
      else zero;
    * names containing ``batchnorm``: zero, except ``gamma`` which is then
      overwritten with samples from a normal with mean 1 and std 0.02;
    * any other parameter is left untouched.
    """
    name = param.name

    if 'conv' in name:
        initializer = mx.init.Normal(0.02) if 'weight' in name else mx.init.Zero()
        param.initialize(init=initializer, ctx=ctx)
        return

    if 'batchnorm' not in name:
        return

    param.initialize(init=mx.init.Zero(), ctx=ctx)
    if 'gamma' in name:
        # Initialize gamma from normal distribution with mean 1 and std 0.02
        gamma_shape = param.data(ctx[0]).shape
        param.set_data(nd.random_normal(1, 0.02, gamma_shape))


def network_init(net):
    """Run ``param_init`` on every parameter returned by ``net.collect_params()``."""
    all_params = net.collect_params()
    for parameter in all_params.values():
        param_init(parameter)


def set_network():
    """Create the generator, the discriminator and one Adam trainer for each.

    Returns:
        The tuple ``(netG, netD, trainerG, trainerD)``.
    """
    def make_trainer(net):
        # Each trainer gets its own options dict.
        return gluon.Trainer(net.collect_params(), 'adam',
                             {'learning_rate': lr, 'beta1': beta1})

    # Pixel2pixel networks
    generator = UnetGenerator(in_channels=3, num_downs=8)
    discriminator = Discriminator(in_channels=6)

    # Initialize parameters, generator first
    for net in (generator, discriminator):
        network_init(net)

    # Trainers for the generator and the discriminator
    generator_trainer = make_trainer(generator)
    discriminator_trainer = make_trainer(discriminator)

    return generator, discriminator, generator_trainer, discriminator_trainer


# Loss
# GAN_loss scores the discriminator output against 0/1 labels; L1_loss
# compares the generated image with the target image.
GAN_loss = gluon.loss.SigmoidBinaryCrossEntropyLoss()
L1_loss = gluon.loss.L1Loss()

# Build the networks and trainers once at module level; train() and
# print_result() use these globals.
netG, netD, trainerG, trainerD = set_network()


def facc(label, pred):
    """Return the fraction of predictions that match ``label`` after thresholding.

    Both arrays are flattened; a prediction counts as 1 when it is strictly
    greater than 0.5 and as 0 otherwise.
    """
    predicted_positive = pred.ravel() > 0.5
    matches = predicted_positive == label.ravel()
    return matches.mean()


def train():
    """Train the pix2pix generator and discriminator on every device in ``ctx``.

    Each batch is split across the devices. For every slice the discriminator
    loss is recorded and back-propagated, then ``trainerD`` steps; the same is
    then done for the generator with ``trainerG``. Losses and the running
    accuracy are logged every ten batches, and one generated image is shown
    at the end of each epoch.
    """
    # One history pool per device, so a pooled image is only ever fed to the
    # discriminator on the device it was created on.
    image_pools = {c: ImagePool(pool_size) for c in ctx}
    metric = mx.metric.CustomMetric(facc)

    logging.basicConfig(level=logging.DEBUG)

    for epoch in range(epochs):
        tic = time.time()
        btic = time.time()
        train_data.reset()
        # Stays None if the iterator yields no batch, so there is nothing to show.
        fake_out = None
        for batch_idx, batch in enumerate(train_data):
            ############################
            # (1) Update D network: maximize log(D(x, y)) + log(1 - D(x, G(x, z)))
            ###########################
            batch_size_loaded = batch.data[0].shape[0]
            real_in = utils.split_and_load(batch.data[0], ctx)
            real_out = utils.split_and_load(batch.data[1], ctx)

            # Generated outside autograd.record(): this pass only feeds D.
            generated = [netG(ri) for ri in real_in]
            fake_concat = [image_pools[ri.context].query(nd.concat(ri, fo, dim=1))
                           for ri, fo in zip(real_in, generated)]
            errD = []
            errG = []

            for fc, ri, ro in zip(fake_concat, real_in, real_out):
                with autograd.record():
                    # Train with fake image
                    # Use image pooling to utilize history images
                    output = netD(fc)
                    fake_label = nd.zeros(output.shape, ctx=fc.context)
                    errD_fake = GAN_loss(output, fake_label)
                    metric.update([fake_label], [output])

                    # Train with real image
                    real_concat = nd.concat(ri, ro, dim=1)
                    output = netD(real_concat)
                    real_label = nd.ones(output.shape, ctx=fc.context)
                    errD_real = GAN_loss(output, real_label)
                    errD.append((errD_real + errD_fake) * 0.5)
                    metric.update([real_label], [output])

            for err in errD:
                err.backward()

            trainerD.step(batch_size_loaded)

            ############################
            # (2) Update G network: maximize log(D(x, G(x, z))) - lambda1 * L1(y, G(x, z))
            ###########################
            for ri, ro in zip(real_in, real_out):
                with autograd.record():
                    fake_out = netG(ri)
                    fake_concat = nd.concat(ri, fake_out, dim=1)
                    output = netD(fake_concat)
                    real_label = nd.ones(output.shape, ctx=ri.context)
                    errG.append(GAN_loss(output, real_label) + L1_loss(ro, fake_out) * lambda1)

            for err in errG:
                err.backward()

            trainerG.step(batch_size_loaded)

            # Print log information every ten batches
            if batch_idx % 10 == 0:
                name, acc = metric.get()
                # Average the per-device mean losses; adding them up without
                # dividing would report a value that grows with the device count.
                avg_errD = sum(err.mean().asscalar() for err in errD) / len(errD)
                avg_errG = sum(err.mean().asscalar() for err in errG) / len(errG)

                logging.info('speed: {} samples/s'.format(batch_size / (time.time() - btic)))
                logging.info('discriminator loss = %f, generator loss = %f, binary training acc = %f at iter %d epoch %d'
                             % (avg_errD, avg_errG, acc, batch_idx, epoch))
            btic = time.time()

        name, acc = metric.get()
        metric.reset()
        logging.info('\nbinary training acc at epoch %d: %s=%f' % (epoch, name, acc))
        logging.info('time: %f' % (time.time() - tic))

        # Visualize one generated image for each epoch (first image of the
        # last device's slice of the last batch).
        if fake_out is not None:
            fake_img = fake_out[0]
            visualize(fake_img)
            plt.show()

# Run the full training loop when the script is executed.
train()


def print_result():
    """Show four validation inputs (top row) above the generator's output for each.

    Consumes one batch from the module-level ``val_data`` iterator and runs
    ``netG`` on the first device in ``ctx``.
    """
    num_image = 4
    inputs, _ = val_data.next().data
    for idx in range(num_image):
        # Add a leading axis of size 1 so a single image can be fed to netG.
        sample = nd.expand_dims(inputs[idx], axis=0)
        plt.subplot(2, 4, idx + 1)
        visualize(sample[0])

        generated = netG(sample.as_in_context(ctx[0]))
        plt.subplot(2, 4, idx + 5)
        visualize(generated[0])
    plt.show()


# Display generator outputs for a few validation images.
print_result()

Final losses:

INFO:root:time: 2.431973
INFO:root:speed: 152.46581218329467 samples/s
INFO:root:discriminator loss = 1.934726, generator loss = 61.413872, binary training acc = 0.594510 at iter 0 epoch 99
INFO:root:
binary training acc at epoch 99: facc=0.736928
INFO:root:time: 2.345976

Resulting image:

image

1 Like