Policy Gradient

class PolicyGradient(model, lr)

Bases: parl.core.paddle.algorithm.Algorithm

__init__(model, lr)

Initialize the policy gradient algorithm.

Parameters
  • model (parl.Model) – model that defines the forward network of the policy.

  • lr (float) – learning rate.
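The `model` argument maps observations to action probabilities. Because `predict` returns action probabilities, the forward pass presumably ends in a softmax. The following is an illustration only: a NumPy stand-in for such a policy network. `LinearSoftmaxPolicy` and its `forward` method are hypothetical names that are not part of PARL, and a real `parl.Model` would be built from Paddle layers.

```python
import numpy as np


class LinearSoftmaxPolicy:
    """Hypothetical stand-in for a parl.Model: observation -> action probabilities."""

    def __init__(self, obs_dim: int, action_dim: int, seed: int = 0):
        rng = np.random.default_rng(seed)
        # Weights of a single linear layer, shape (obs_dim, action_dim).
        self.W = 0.01 * rng.standard_normal((obs_dim, action_dim))

    def forward(self, obs: np.ndarray) -> np.ndarray:
        # obs has shape (obs_dim,) or (batch_size, obs_dim).
        logits = obs @ self.W
        # Subtract the row max before exponentiating for numerical stability.
        logits = logits - logits.max(axis=-1, keepdims=True)
        exp = np.exp(logits)
        # Normalize so each row sums to 1 (softmax).
        return exp / exp.sum(axis=-1, keepdims=True)


policy = LinearSoftmaxPolicy(obs_dim=4, action_dim=2)
prob = policy.forward(np.ones(4))
print(prob.shape)  # (2,)
```

The same `forward` call also accepts a batch of shape (batch_size, obs_dim) and returns one probability row per observation.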

learn(obs, action, reward)

Update the model with the policy gradient algorithm.
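The docstring does not state the loss. For reference, the standard REINFORCE form of such an update is shown below, and it is an assumption that `learn` follows it. Here $B$ is `batch_size`; $s_i$, $a_i$ and $r_i$ are row $i$ of `obs`, `action` and `reward`; $\theta$ are the parameters of `model`; and $\eta$ is `lr`.

```latex
L(\theta) = -\frac{1}{B}\sum_{i=1}^{B} r_i \,\log \pi_\theta(a_i \mid s_i),
\qquad
\nabla_\theta L(\theta) = -\frac{1}{B}\sum_{i=1}^{B} r_i \,\nabla_\theta \log \pi_\theta(a_i \mid s_i),
\qquad
\theta \leftarrow \theta - \eta\, \nabla_\theta L(\theta)
```

A descent step on $L(\theta)$ raises $\log \pi_\theta(a_i \mid s_i)$ for samples with positive $r_i$ and lowers it for samples with negative $r_i$.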

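Parameters
  • obs (paddle tensor) – batch of observations, with shape (batch_size, obs_dim)

  • action (paddle tensor) – actions taken for those observations, with shape (batch_size, 1)

  • reward (paddle tensor) – reward associated with each action, with shape (batch_size, 1)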

Returns

training loss, with shape (1)

Return type

loss (paddle tensor)
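The following is a minimal NumPy sketch of what one such update could look like. It assumes the REINFORCE loss, mean(-log π(a|s) · reward), applied to a hypothetical linear-softmax policy (logits = obs @ W) with a plain gradient-descent step of size `lr`. It is not PARL's implementation, which delegates to Paddle's autograd and optimizer. `pg_learn_step` is a hypothetical helper, and it returns a Python scalar loss rather than a shape-(1) Paddle tensor.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    # Row-wise softmax with max-subtraction for numerical stability.
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def pg_learn_step(W, obs, action, reward, lr):
    """One REINFORCE update on a hypothetical linear-softmax policy.

    obs: (batch_size, obs_dim); action: (batch_size, 1) integer indices;
    reward: (batch_size, 1). Returns (updated W, scalar loss).
    """
    batch_size = obs.shape[0]
    prob = softmax(obs @ W)                              # (batch_size, action_dim)
    a = action[:, 0]                                     # (batch_size,)
    log_prob = np.log(prob[np.arange(batch_size), a])    # log pi(a_i | s_i)
    loss = float(np.mean(-log_prob * reward[:, 0]))      # surrogate loss
    # Gradient of the loss w.r.t. the logits: r_i * (prob_i - onehot(a_i)) / batch_size.
    onehot = np.eye(W.shape[1])[a]
    grad_logits = reward * (prob - onehot) / batch_size  # (batch_size, action_dim)
    grad_W = obs.T @ grad_logits                         # (obs_dim, action_dim)
    # Plain gradient-descent step with learning rate lr.
    return W - lr * grad_W, loss


obs = np.array([[1.0, 0.0, 0.5]])
action = np.array([[1]])
reward = np.array([[2.0]])
W = np.zeros((3, 2))
W_new, loss = pg_learn_step(W, obs, action, reward, lr=0.1)
```

With a positive reward, the step increases the probability of the action that was taken for that observation. A negative reward would push it down instead.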

predict(obs)

Predict the probability of each action.

Parameters

obs (paddle tensor) – a single observation, with shape (obs_dim,)

Returns

probability of each action, with shape (action_dim,)

Return type

prob (paddle tensor)
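`predict` returns a probability vector rather than an action. A common pattern, assumed here because this page does not prescribe one, is to sample from that vector while collecting training data and to take the argmax during evaluation. The sketch below operates on a NumPy array, for example one obtained from the Paddle tensor with `prob.numpy()`. `sample_action` and `greedy_action` are hypothetical helpers.

```python
import numpy as np


def sample_action(prob: np.ndarray, rng: np.random.Generator) -> int:
    # Stochastic choice for training: draw index i with probability prob[i].
    return int(rng.choice(len(prob), p=prob))


def greedy_action(prob: np.ndarray) -> int:
    # Deterministic choice for evaluation: the most probable action.
    return int(np.argmax(prob))


prob = np.array([0.1, 0.7, 0.2])  # e.g. an (action_dim,) output of predict
rng = np.random.default_rng(0)
print(greedy_action(prob))  # 1
print(sample_action(prob, rng))
```

Sampling keeps the agent exploring in proportion to its current policy. The greedy choice removes that randomness when only performance is being measured.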