Deep Deterministic Policy Gradient (DDPG) in Gluon

Is there any example implementation of Deep Deterministic Policy Gradient (DDPG) for the Gluon API? If there isn’t one, can someone help me with the implementation?
I tried to implement it myself, but I got stuck at the point where I have to update the actor network.

I implemented the following training routine:

if do_training():
    # Sample random batch from replay buffer
    states, actions, rewards, next_states, terminals = replay_buffer.sample(batch_size=BATCH_SIZE)
    
    # Calculate target y with actor and critic target networks
    target_actions = actor_target(next_states)
    target_qvalues = critic_target(next_states, target_actions)
    y = rewards + (1.0 - terminals) * DISCOUNT_FACTOR * target_qvalues

    # Update critic network by minimizing the TD error
    with autograd.record():
        qvalues = critic(states, actions)
        loss = l2_loss(qvalues, y)
    loss.backward()
    trainer_critic.step(BATCH_SIZE)  # actual update with gluon.trainer

    # Let the actor propose actions for the given states
    actor_action = actor(states)
    actor_action.attach_grad()

    # Compute Q(state, action) and backpropagate w.r.t. actions
    with autograd.record():
        qvalues = critic(states, actor_action)
    qvalues.backward()
    action_gradients = actor_action.grad

My first problem is that all the gradients in action_gradients are the same across the whole batch, so I am not sure whether this is correct. My second problem is that I do not know how to proceed with the algorithm: how can I update the actor weights with the gradients calculated from the critic network?

Hi @nifeles,

You should have one trainer for the critic network and another for the actor network; these are what actually update the parameters of the respective networks. You don’t seem to run the actor while recording with autograd, so the gradients won’t be passed all the way back to the actor parameters. Also, is qvalues actually a loss you’re trying to minimize?
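
For illustration, here is a minimal sketch of the two-trainer setup, assuming actor and critic are already built and initialized gluon.Blocks (the optimizer choice and learning rates below are just placeholders, not something from your post):

from mxnet import gluon

# One Trainer per network, each bound only to that network's own parameters
trainer_actor = gluon.Trainer(actor.collect_params(), 'adam', {'learning_rate': 1e-4})
trainer_critic = gluon.Trainer(critic.collect_params(), 'adam', {'learning_rate': 1e-3})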

I see you’re trying to manually grab the gradients with respect to the actions, presumably so you can mirror the paper, where those gradients are multiplied by the gradients from the actor’s forward pass. While you can do it like that, in effect you’re just doing backprop manually when autograd could do it for you. That is, modify to:

with autograd.record():
    qvalues = critic(states, actor(states))
    actor_loss = -1 * qvalues  # if we want to maximize Q-values with a minimizer, multiply by -1
actor_loss.backward()
trainer_actor.step(BATCH_SIZE)
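
For reference, the sampled policy gradient from the DDPG paper (Lillicrap et al., 2015) that the manual approach corresponds to is

\nabla_{\theta^\mu} J \approx \frac{1}{N} \sum_i \nabla_a Q(s, a \mid \theta^Q)\big|_{s = s_i,\, a = \mu(s_i)} \, \nabla_{\theta^\mu} \mu(s \mid \theta^\mu)\big|_{s = s_i}

and running the actor inside autograd.record() and calling backward() on the negated Q-values computes exactly this product for you.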

Importantly, for this to work you need to make sure that your actor trainer operates only on the actor network’s parameters and not on the critic’s. Otherwise the optimization step will also change the critic parameters, which you don’t want to happen in the actor update!
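
One extra detail, as a sketch of my own rather than something you strictly need: because the actor loss is backpropagated through the critic, gradients also accumulate in the critic’s parameters during this step. With the default grad_req='write' they are simply overwritten by the critic’s own backward pass before the next trainer_critic.step(), so they do no harm, but if you want to block them explicitly you can temporarily disable gradient storage for the critic parameters:

# Optional: don't store gradients for the critic parameters during the
# actor update, then restore the default behaviour afterwards
critic.collect_params().setattr('grad_req', 'null')
with autograd.record():
    actor_loss = -1 * critic(states, actor(states))
actor_loss.backward()
trainer_actor.step(BATCH_SIZE)
critic.collect_params().setattr('grad_req', 'write')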

Yes, I tried to mirror the paper and wanted to compute the gradient dJ/dTheta = dQ/da * dMu/dTheta manually. I did not realize that I can simply maximize the Q-values directly here. Thank you both very much for your answers, you helped me a lot.