Attempting to use augmentation during training

I try to implement custom version of contrastive learning. I get a network output from a sample, then a second output from augmented sample and penalize loss between those outputs.
I try to use gluon transforms: from mxnet.gluon.data.vision import transforms.

First problem is that opencv does not work on GPU. I made a workaround by moving samples to CPU and then back to GPU.
Second problem is that image transforms only work on (W,H,3) arrays. Added artificial repeating and cutting back.
And the third problem, which I could not solve, was (if I understand right) that MXNet tries to backpropagate through transforms in some way. I tried to use stop_gradient with no effect. The lines where error happens are

                    _, z1 = net(x)
                    # here I use transforms, which cause error during backward pass
                    x_aug = augment_monochrome(x, joint_transform, noise_ampl=.2, opencv_gpu_fix=True)
                    _, z1aug = net(x_aug)
                    z1augStopGrad = F.stop_gradient(z1aug)
                    loss = 1. - _cosine_similarity(z1, z1augStopGrad)
                    # here the error happens
                    loss.backward()

How can I make gradient to flow only through a main branch without sample augmentation?
P.S. Can i make transforms faster than using F.stack(*[F.swapaxes(...) ...]) loop? The documentation 1.8.0 says that for all transform functions

Inputs:

    data: input tensor with (C x H x W) or (N x C x H x W) shape.

but when I try to feed samples without loop, in (N x C x H x W) form, I get error
File "/home/ai/store/sources/incubator-mxnet/python/mxnet/image/image.py", line 588, in random_size_crop

 h, w, _ = src.shape

And only (H x W x C) format actually works.

SYSTEM, CUDA, MXNET INFO:

$ uname -a

Linux NeuralPC 5.11.0-25-generic #27~20.04.1-Ubuntu SMP Tue Jul 13 17:41:23 UTC 2021 x86_64 x86_64 x86_64 GNU/Linux

CUDA version is 11.4:

$ nvcc --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2021 NVIDIA Corporation
Built on Wed_Jun__2_19:15:15_PDT_2021
Cuda compilation tools, release 11.4, V11.4.48
Build cuda_11.4.r11.4/compiler.30033411_0

MXNet is built from git:

$ git log

commit 8f1c38adb1fbfe04ec335318bf127840cf14e142 (HEAD -> v1.x, origin/v1.x)
...

$ python3 -c ‘import mxnet; print(mxnet.version)’

1.9.0

Here is the error:

Traceback (most recent call last):
  File "/home/ai/learning/contrastive_learning/my_contrastive_learning/contrastive_learning.py", line 183, in <module>
    contrastive_pretrain(net, train_data, epoch=4)
  File "/home/ai/learning/contrastive_learning/my_contrastive_learning/contrastive_learning.py", line 109, in contrastive_pretrain
    loss.backward()
  File "/home/ai/store/sources/incubator-mxnet/python/mxnet/ndarray/ndarray.py", line 2869, in backward
    check_call(_LIB.MXAutogradBackwardEx(
  File "/home/ai/store/sources/incubator-mxnet/python/mxnet/base.py", line 246, in check_call
    raise get_last_ffi_error()
mxnet.base.MXNetError: Traceback (most recent call last):
  File "/home/ai/store/sources/incubator-mxnet/src/nnvm/gradient.cc", line 213
MXNetError: Operator _cvimresize is non-differentiable because it didn't register FGradient attribute.

The code:

 # based on https://mxnet.apache.org/versions/1.8.0/api/python/docs/tutorials/packages/gluon/image/mnist.html
from __future__ import print_function
import mxnet as mx
import numpy as np

from mxnet import gluon
from mxnet.gluon import nn
from mxnet import autograd as ag

import mxnet.ndarray as F
from mxnet.gluon.data.vision import transforms

ctx = mx.gpu() if mx.test_utils.list_gpus() else mx.cpu()
EPS_ARR = F.array(np.array([1e-12])).as_in_context(ctx)


def train(net, train_data, epoch=100):
    metric = mx.metric.Accuracy()
    softmax_cross_entropy_loss = gluon.loss.SoftmaxCrossEntropyLoss()
    for i in range(epoch):
        train_data.reset()
        for batch in train_data:
            data = gluon.utils.split_and_load(batch.data[0], ctx_list=[ctx], batch_axis=0)
            label = gluon.utils.split_and_load(batch.label[0], ctx_list=[ctx], batch_axis=0)
            outputs = []
            with ag.record():
                for x, y in zip(data, label):
                    z, _ = net(x)
                    # Computes softmax cross entropy loss.
                    loss = softmax_cross_entropy_loss(z, y)
                    # Backpropagate the error for one iteration.
                    loss.backward()
                    outputs.append(z)
            metric.update(label, outputs)
            trainer.step(batch.data[0].shape[0])
        name, acc = metric.get()
        metric.reset()
        print('training acc at epoch %d: %s=%f' % (i, name, acc))


def augment_monochrome(x,
                       joint_transform,
                       noise_ampl=.2,
                       opencv_gpu_fix=True):
    if opencv_gpu_fix:
        x_aug0 = x.as_in_context(mx.cpu())
    else:
        x_aug0 = x
    x_aug = F.repeat(x_aug0, repeats=3, axis=1)
    x_aug_ = F.stack(*[F.swapaxes(
        (1 + noise_ampl * np.random.normal()) * joint_transform(
            F.swapaxes(x_aug[i, ...], 0, 2)
        ) + noise_ampl * np.random.normal(), 0, 2
    ) for i in range(batch_size)])
    if opencv_gpu_fix:
        x_aug = x_aug_.as_in_context(ctx)
    else:
        x_aug = x_aug_
    x_aug = x_aug[:, :1, :, :]
    x_aug = F.clip(x_aug, 0., 1.)
    return x_aug


joint_transform = transforms.Compose([
    # transforms.RandomBrightness(.2),
    # transforms.RandomContrast(.3),
    transforms.RandomRotation(angle_limits=(-30, 30), zoom_in=True),
    transforms.RandomResizedCrop(size=28, scale=(.7, 1.), ratio=(.8, 1.25))
])


# based on incubator-mxnet/python/mxnet/gluon/loss.py CosineEmbeddingLoss class
def _cosine_similarity(x, y, axis=-1):
    x_norm = F.norm(x, axis=axis).reshape((-1, 1))
    y_norm = F.norm(y, axis=axis).reshape((-1, 1))
    x_dot_y = F.sum(x * y, axis=axis).reshape((-1, 1))
    EPS_ARR = F.array(np.array([1e-12])).as_in_context(ctx)
    return x_dot_y / F.broadcast_maximum(x_norm * y_norm, EPS_ARR)


def contrastive_pretrain(net, train_data, epoch=100):
    for i in range(epoch):
        train_data.reset()
        for batch in train_data:
            data = gluon.utils.split_and_load(batch.data[0], ctx_list=[ctx], batch_axis=0)
            label = gluon.utils.split_and_load(batch.label[0], ctx_list=[ctx], batch_axis=0)
            with ag.record():
                for x, y in zip(data, label):
                    _, z1 = net(x)
                    x_aug = augment_monochrome(x, joint_transform, noise_ampl=.2, opencv_gpu_fix=True)
                    _, z1aug = net(x_aug)
                    z1augStopGrad = F.stop_gradient(z1aug)
                    loss = 1. - _cosine_similarity(z1, z1augStopGrad)
                    loss.backward()
            trainer.step(batch.data[0].shape[0])


def validate(net, val_data):
    # Use Accuracy as the evaluation metric.
    metric = mx.metric.Accuracy()
    val_data.reset()
    for batch in val_data:
        data = gluon.utils.split_and_load(batch.data[0], ctx_list=[ctx], batch_axis=0)
        label = gluon.utils.split_and_load(batch.label[0], ctx_list=[ctx], batch_axis=0)
        outputs = []
        for x in data:
            outputs.append(net(x)[0])
        metric.update(label, outputs)
    validation_metric = metric.get()
    return validation_metric


class Net(gluon.Block):
    def __init__(self, **kwargs):
        super(Net, self).__init__(**kwargs)
        with self.name_scope():
            self.conv1 = nn.Conv2D(2, kernel_size=(5, 5))
            self.pool1 = nn.MaxPool2D(pool_size=(2, 2), strides=(2, 2))
            self.conv2 = nn.Conv2D(2, kernel_size=(5, 5))
            self.pool2 = nn.MaxPool2D(pool_size=(2, 2), strides=(2, 2))
            self.fc1 = nn.Dense(16)
            self.fc2 = nn.Dense(10)  # this is equal to number of classes

    def forward(self, x):
        x = self.pool1(F.tanh(self.conv1(x)))
        x = self.pool2(F.tanh(self.conv2(x)))
        x = x.reshape((0, -1))
        x_last = F.tanh(self.fc1(x))
        out = F.tanh(self.fc2(x_last))
        return out, x_last


if __name__ == '__main__':
    # Fixing the random seed
    mx.random.seed(33)

    mnist = mx.test_utils.get_mnist()

    batch_size = 100
    train_data = mx.io.NDArrayIter(mnist['train_data'], mnist['train_label'], batch_size, shuffle=True)
    val_data = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size)

    # # trying without contrastive pretraining
    # net = Net()
    #
    # # set the context on GPU is available otherwise CPU
    # gpus = mx.test_utils.list_gpus()
    # ctx = [mx.gpu() if mx.test_utils.list_gpus() else mx.cpu()]
    # net.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx)
    # trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.03})
    #
    # train(net, train_data, epoch=32)
    #
    # validation_metric = validate(net, val_data)
    # print('validation acc: %s=%f' % validation_metric)
    # # assert validation_metric[1] > 0.98 # only if network is good

    # now with contrastive pretraining!
    net = Net()

    # set the context on GPU is available otherwise CPU
    gpus = mx.test_utils.list_gpus()
    net.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx)
    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.03})

    contrastive_pretrain(net, train_data, epoch=4)
    train(net, train_data, epoch=28)

    validation_metric = validate(net, val_data)
    # print('validation acc with pretraining: %s=%f' % validation_metric)

No answer.
I ended up using a conversion to numpy array and back:

def augment_monochrome(x,
                       joint_transform,
                       noise_ampl=.2,
                       opencv_gpu_fix=True):
    print('augment_monochrome call', end='...')
    if opencv_gpu_fix:
        x_aug0 = x.as_in_context(mx.cpu())
    else:
        x_aug0 = x
    x_aug = F.repeat(x_aug0, repeats=3, axis=1)
    x_aug_ = F.stack(*[F.swapaxes(
        (1 + noise_ampl * np.random.normal()) * joint_transform(
            F.swapaxes(x_aug[i, ...], 0, 2)
        ) + noise_ampl * np.random.normal(), 0, 2
    ) for i in range(batch_size)])
    if opencv_gpu_fix:
        x_aug = x_aug_.as_in_context(ctx)
    else:
        x_aug = x_aug_
    x_aug = x_aug[:, :1, :, :]
    x_aug = F.clip(x_aug, 0., 1.)
    print('augment_monochrome finish', end='...')
    return x_aug

slow, but it’s honest work…