Policy Gradient

class PolicyGradient(model, lr)

Bases: parl.core.paddle.algorithm.Algorithm

__init__(model, lr)

Policy gradient algorithm.

Parameters
  • model (parl.Model) – model that defines the forward network of the policy.

  • lr (float) – learning rate.

learn(obs, action, reward)

Update the model with the policy gradient algorithm.

Parameters
  • obs (paddle tensor) – shape of (batch_size, obs_dim)

  • action (paddle tensor) – shape of (batch_size, 1)

  • reward (paddle tensor) – shape of (batch_size, 1)

Returns

loss, shape of (1)

Return type

paddle tensor
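
The reference above does not spell out the loss that `learn` minimizes. As a rough illustration only, the sketch below computes a common REINFORCE-style surrogate, the batch mean of -log π(a|s) · reward, in plain Python, with nested lists standing in for paddle tensors of the shapes listed above. The helper name `policy_gradient_loss` and the use of precomputed action probabilities are assumptions made for this example, not part of PARL's API.

```python
import math


def policy_gradient_loss(probs, actions, rewards):
    """Hypothetical helper: mean of -log(prob of taken action) * reward.

    probs:   batch_size rows, each a list of action_dim probabilities
    actions: batch_size rows, each [action_index]  (shape (batch_size, 1))
    rewards: batch_size rows, each [reward]        (shape (batch_size, 1))
    """
    if not probs or not (len(probs) == len(actions) == len(rewards)):
        raise ValueError("inputs must be non-empty and share batch_size")
    total = 0.0
    for prob_row, (action,), (reward,) in zip(probs, actions, rewards):
        # Negative log-likelihood of the taken action, weighted by its reward.
        total += -math.log(prob_row[action]) * reward
    return total / len(probs)


probs = [[0.5, 0.5], [0.25, 0.75]]   # assumed policy outputs for two observations
actions = [[0], [1]]                 # actions taken
rewards = [[1.0], [2.0]]             # reward attached to each action
loss = policy_gradient_loss(probs, actions, rewards)
```

Minimizing this quantity raises the probability of actions paired with large rewards. In a tensor framework the same expression would be differentiated with respect to the model parameters and followed by an optimizer step.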

predict(obs)

Predict the probability of each action.

Parameters

obs (paddle tensor) – shape of (obs_dim,)

Returns

prob, shape of (action_dim,)

Return type

paddle tensor
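
`predict` returns action probabilities rather than an action, so the caller still has to pick one. Below is a minimal sketch of two common choices, sampling (typical during training) and taking the most probable action (typical during evaluation), using a plain list in place of the paddle tensor. The helper names `sample_action` and `greedy_action` are hypothetical and not part of PARL.

```python
import random


def sample_action(prob, rng=random):
    """Draw an action index according to the probabilities in prob."""
    return rng.choices(range(len(prob)), weights=prob, k=1)[0]


def greedy_action(prob):
    """Return the index of the most probable action."""
    return max(range(len(prob)), key=lambda i: prob[i])


prob = [0.1, 0.7, 0.2]      # stands in for an output of shape (action_dim,)
rng = random.Random(0)      # seeded generator so runs are repeatable
action = sample_action(prob, rng)
best = greedy_action(prob)
```

Sampling keeps exploration proportional to the policy's own uncertainty, while the greedy choice gives deterministic behavior.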