PARL is a flexible, distributed, and object-oriented programming framework for reinforcement learning.

Features

Object Oriented Programming

Distributed Training

import copy
import parl
from parl import layers

class MLPModel(parl.Model):
  def __init__(self, act_dim):
    self.fc1 = layers.fc(size=10)
    self.fc2 = layers.fc(size=act_dim)

  def forward(self, obs):
    out = self.fc1(obs)
    out = self.fc2(out)
    return out

model = MLPModel(act_dim=4)  # act_dim must match the task's action space
target_model = copy.deepcopy(model)
# Real multi-thread programming
# without the GIL limitation

@parl.remote_class
class HelloWorld(object):
    def sum(self, a, b):
        return a + b

parl.connect('localhost:8003')
obj = HelloWorld()
ans = obj.sum(1, 5)  # the computation runs remotely on the cluster

Abstractions

_images/abstractions.png
PARL aims to build an agent that can be trained by different algorithms to perform complex tasks.
The main abstractions introduced by PARL, which are composed layer by layer to build an agent, are the following (a minimal sketch of how they fit together appears after the list):
  • Model is abstracted to construct the forward network, which defines a policy network or critic network that takes the state as input.

  • Algorithm describes the mechanism to update parameters in the model and often contains at least one model.

  • Agent, a data bridge between the environment and the algorithm, is responsible for data I/O with the outside environment and describes data preprocessing before feeding data into the training process.
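As a rough sketch of how these pieces compose (using the CartpoleModel and CartpoleAgent classes defined later in the Getting Started tutorial; the hyperparameter values are illustrative only):

model = CartpoleModel(act_dim=2)                             # Model: the forward network
algorithm = parl.algorithms.PolicyGradient(model, lr=1e-3)   # Algorithm: owns the model and its update rule
agent = CartpoleAgent(algorithm, obs_dim=4, act_dim=2)       # Agent: the data bridge to the environment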

Installation

Dependencies

  • Python 2.7 or 3.5+.

  • PaddlePaddle >=1.5.1 (optional; not required if you only want to use the parallelization APIs)

Install

PARL is distributed on PyPI and can be installed with pip:

pip install parl

or install from source:

pip install --upgrade git+https://github.com/PaddlePaddle/PARL.git

Features

1. Reproducible

We provide implementations that stably reproduce the results of many influential reinforcement learning algorithms.

2. Large Scale

Supports high-performance parallel training with thousands of CPUs and multiple GPUs.

3. Reusable

Algorithms provided in the repository can be adapted to new tasks directly: define a forward network, and the training mechanism is built automatically.

4. Extensible

Build new algorithms quickly by inheriting the abstract class in the framework.

Getting Started

Goal of this tutorial:

  • Understand PARL’s abstraction at a high level

  • Train an agent to solve the Cartpole problem with the Policy Gradient algorithm

This tutorial assumes that you have a basic familiarity with policy gradient methods.

Model

First, let’s build a Model that predicts an action given an observation. As an object-oriented programming framework, PARL builds models on top of parl.Model, and we implement the forward function.

Here, we construct a neural network with two fully connected layers.

import parl
from parl import layers

class CartpoleModel(parl.Model):
    def __init__(self, act_dim):
        hid1_size = act_dim * 10

        self.fc1 = layers.fc(size=hid1_size, act='tanh')
        self.fc2 = layers.fc(size=act_dim, act='softmax')

    def forward(self, obs):
        out = self.fc1(obs)
        out = self.fc2(out)
        return out

Algorithm

Algorithm updates the parameters of the model passed to it; in general, the loss function is defined in the Algorithm. In this tutorial, we solve the Cartpole benchmark with the Policy Gradient algorithm, which is already implemented in our repository, so we can simply import it from parl.algorithms.

We have also published various algorithms in PARL; please visit this page for more details. For those who want to implement a new algorithm, please follow this tutorial.

model = CartpoleModel(act_dim=2)
algorithm = parl.algorithms.PolicyGradient(model, lr=1e-3)

Note that each algorithm should have two functions implemented:

  • learn

    updates the model’s parameters given transition data

  • predict

    predicts an action given current environmental state.

Agent

Now we pass the algorithm to an agent, which interacts with the environment to generate training data. Users should build their agents on top of parl.Agent and implement four functions:

  • build_program

    defines the fluid programs. In general, two programs are built here: one for prediction and the other for training.

  • learn

    preprocesses the transition data and feeds it into the training program.

  • predict

    feeds the current environmental state into the prediction program and returns an action to execute.

  • sample

    samples an action given the current state; it is usually used for exploration.

import numpy as np
from paddle import fluid

class CartpoleAgent(parl.Agent):
    def __init__(self, algorithm, obs_dim, act_dim):
        self.obs_dim = obs_dim
        self.act_dim = act_dim
        super(CartpoleAgent, self).__init__(algorithm)

    def build_program(self):
        self.pred_program = fluid.Program()
        self.train_program = fluid.Program()

        with fluid.program_guard(self.pred_program):
            obs = layers.data(
                name='obs', shape=[self.obs_dim], dtype='float32')
            self.act_prob = self.alg.predict(obs)

        with fluid.program_guard(self.train_program):
            obs = layers.data(
                name='obs', shape=[self.obs_dim], dtype='float32')
            act = layers.data(name='act', shape=[1], dtype='int64')
            reward = layers.data(name='reward', shape=[], dtype='float32')
            self.cost = self.alg.learn(obs, act, reward)

    def sample(self, obs):
        obs = np.expand_dims(obs, axis=0)
        act_prob = self.fluid_executor.run(
            self.pred_program,
            feed={'obs': obs.astype('float32')},
            fetch_list=[self.act_prob])[0]
        act_prob = np.squeeze(act_prob, axis=0)
        act = np.random.choice(range(self.act_dim), p=act_prob)
        return act

    def predict(self, obs):
        obs = np.expand_dims(obs, axis=0)
        act_prob = self.fluid_executor.run(
            self.pred_program,
            feed={'obs': obs.astype('float32')},
            fetch_list=[self.act_prob])[0]
        act_prob = np.squeeze(act_prob, axis=0)
        act = np.argmax(act_prob)
        return act

    def learn(self, obs, act, reward):
        act = np.expand_dims(act, axis=-1)
        feed = {
            'obs': obs.astype('float32'),
            'act': act.astype('int64'),
            'reward': reward.astype('float32')
        }
        cost = self.fluid_executor.run(
            self.train_program, feed=feed, fetch_list=[self.cost])[0]
        return cost

Start Training

First, let’s build an agent. As the code below shows, we build a model, then an algorithm, and finally an agent.

model = CartpoleModel(act_dim=2)
alg = parl.algorithms.PolicyGradient(model, lr=1e-3)
agent = CartpoleAgent(alg, obs_dim=OBS_DIM, act_dim=2)  # OBS_DIM is 4 for CartPole

Then we use this agent to interact with the environment, and run around 1000 episodes for training, after which this agent can solve the problem.

def run_episode(env, agent, train_or_test='train'):
    obs_list, action_list, reward_list = [], [], []
    obs = env.reset()
    while True:
        obs_list.append(obs)
        if train_or_test == 'train':
            action = agent.sample(obs)
        else:
            action = agent.predict(obs)
        action_list.append(action)

        obs, reward, done, info = env.step(action)
        reward_list.append(reward)

        if done:
            break
    return obs_list, action_list, reward_list

# gym, numpy (np) and parl.utils.logger are imported in the full example script;
# GAMMA and calc_discount_norm_reward are also defined there.
env = gym.make("CartPole-v0")
for i in range(1000):
    obs_list, action_list, reward_list = run_episode(env, agent)
    if i % 10 == 0:
        logger.info("Episode {}, Reward Sum {}.".format(i, sum(reward_list)))

    batch_obs = np.array(obs_list)
    batch_action = np.array(action_list)
    batch_reward = calc_discount_norm_reward(reward_list, GAMMA)

    agent.learn(batch_obs, batch_action, batch_reward)
    if (i + 1) % 100 == 0:
        _, _, reward_list = run_episode(env, agent, train_or_test='test')
        total_reward = np.sum(reward_list)
        logger.info('Test reward: {}'.format(total_reward))

Summary

_images/performance.gif _images/quickstart.png

In this tutorial, we have shown how to build an agent step-by-step to solve the Cartpole problem.

The complete training code can be found here. Try it quickly by running a few commands:

# Install dependencies
pip install paddlepaddle

pip install gym
git clone https://github.com/PaddlePaddle/PARL.git
cd PARL
pip install .

# Train model
cd examples/QuickStart/
python train.py

Create Customized Algorithms

Goal of this tutorial:

  • Learn how to implement your own algorithms.

Overview

To build a new algorithm, you need to inherit the class parl.Algorithm and implement three basic functions: sample, predict and learn. (A minimal skeleton is sketched after the method descriptions below.)

Methods

  • __init__

    As an algorithm updates the weights of its models, this method needs to define the models inherited from parl.Model, like self.model in this example. You can also set hyperparameters here, such as learning_rate, reward_decay and action_dimension, which may be used in the following steps.

  • predict

    This function defines how to choose actions. For instance, you can use a policy model to predict actions.

  • sample

    Based on the predict method, sample generates actions with noise. Use this method for exploration if needed.

  • learn

    Define the loss function in the learn method; it will be used to update the weights of self.model.
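Putting these methods together, here is a minimal, hedged skeleton of a custom algorithm; the model is assumed to expose a value() method, and the bodies of sample and learn are placeholders that the DQN example below fills in concretely.

import parl

class CustomAlgorithm(parl.Algorithm):
    def __init__(self, model, lr=None):
        # the model(s) whose weights this algorithm updates, plus hyperparameters
        self.model = model
        self.lr = lr

    def predict(self, obs):
        # choose actions greedily from the model's output
        return self.model.value(obs)

    def sample(self, obs):
        # exploration: perturb the greedy prediction, e.g. with noise or epsilon-greedy
        return self.predict(obs)

    def learn(self, obs, action, reward, next_obs, terminal):
        # define the loss on self.model and the optimizer that minimizes it
        raise NotImplementedError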

Example: DQN

This example shows how to implement DQN algorithm based on class parl.Algorithm according to the steps mentioned above.

Within class DQN(Algorithm), we define the following methods:

  • __init__(self, model, act_dim=None, gamma=None, lr=None)

    We define self.model and self.target_model of DQN in this method; both are instances of parl.Model. We also set the hyperparameters act_dim, gamma and lr here, which are used in the learn method.

    def __init__(self,
                 model,
                 act_dim=None,
                 gamma=None,
                 lr=None):
        """ DQN algorithm
    
        Args:
            model (parl.Model): model defining forward network of Q function
            act_dim (int): dimension of the action space
            gamma (float): discount factor for reward computation.
            lr (float): learning rate.
        """
        self.model = model
        self.target_model = copy.deepcopy(model)
    
        assert isinstance(act_dim, int)
        assert isinstance(gamma, float)
        assert isinstance(lr, float)
        self.act_dim = act_dim
        self.gamma = gamma
        self.lr = lr
    
  • predict(self, obs)

    We use the forward network defined in self.model here, which uses observations to predict action values directly.

    def predict(self, obs):
        """ use value model self.model to predict the action value
        """
        return self.model.value(obs)
    
  • learn(self, obs, action, reward, next_obs, terminal)

    The learn method calculates the cost of the value function from the predicted value and the target value. The agent uses this cost to update the weights of self.model.

    def learn(self, obs, action, reward, next_obs, terminal):
        """ update value model self.model with DQN algorithm
        """
    
        pred_value = self.model.value(obs)
        next_pred_value = self.target_model.value(next_obs)
        best_v = layers.reduce_max(next_pred_value, dim=1)
        best_v.stop_gradient = True
        target = reward + (
            1.0 - layers.cast(terminal, dtype='float32')) * self.gamma * best_v
    
        action_onehot = layers.one_hot(action, self.act_dim)
        action_onehot = layers.cast(action_onehot, dtype='float32')
        pred_action_value = layers.reduce_sum(
            layers.elementwise_mul(action_onehot, pred_value), dim=1)
        cost = layers.square_error_cost(pred_action_value, target)
        cost = layers.reduce_mean(cost)
        optimizer = fluid.optimizer.Adam(self.lr, epsilon=1e-3)
        optimizer.minimize(cost)
        return cost
    
  • sync_target(self)

    Use this method to synchronize the weights of self.target_model with those of self.model; this periodic synchronization is part of the DQN algorithm.

    def sync_target(self, gpu_id=None):
        """ sync weights of self.model to self.target_model
        """
        self.model.sync_weights_to(self.target_model)
    

Save and Restore Parameters

Goal of this tutorial:

  • Learn how to save and restore parameters.

Example

Sometimes we need to save the parameters into a file and reuse them later. PARL provides operators to save parameters to a file and restore them from it easily; it only takes a few lines.

Here is a demonstration of usage:

agent = AtariAgent()
# save the parameters of agent to ./model_dir
agent.save('./model_dir')
# restore the parameters from ./model_dir to agent
agent.restore('./model_dir')

# restore the parameters from ./model_dir to another_agent
another_agent = AtariAgent()
another_agent.restore('./model_dir')

summary

Visualize the results with TensorBoard.

add_scalar

Commonly used arguments:

  • summary.add_scalar(tag, scalar_value, global_step=None)
    • tag (string) – Data identifier

    • scalar_value (float or string/blobname) – Value to save

    • global_step (int) – Global step value to record

Example:

from parl.utils import summary

x = range(100)
for i in x:
    summary.add_scalar('y=2x', i * 2, i)

Expected result:

_images/add_scalar.jpg

add_histogram

Commonly used arguments:

  • summary.add_histogram(tag, values, global_step=None)
    • tag (string) – Data identifier

    • values (torch.Tensor, numpy.array, or string/blobname) – Values to build histogram

    • global_step (int) – Global step value to record

Example:

from parl.utils import summary
import numpy as np

for i in range(10):
    x = np.random.random(1000)
    summary.add_histogram('distribution centers', x + i, i)

Expected result:

_images/add_histogram.jpg

Overview

Easy-to-use

With a single @parl.remote_class decorator, users can implement parallel training easily, without having to deal with multi-process management or network communication.

High performance

@parl.remote_class enables us to achieve real multi-threaded computation efficiency without modifying our code. As shown in figure (a), Python's native multi-threading performs poorly due to the limitation of the GIL, while PARL lets us realize truly parallel computation.
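As a hedged sketch of this point (assuming a cluster has already been started with xparl and that parl.connect points at it), each call on a decorated object runs in a separate process of the cluster, so ordinary Python threads on the local side are no longer serialized by the GIL:

import threading
import parl

@parl.remote_class
class Worker(object):
    def heavy_compute(self, n):
        # CPU-bound work that would normally be serialized by the GIL
        return sum(i * i for i in range(n))

parl.connect('localhost:8003')  # address of an already-running xparl cluster

workers = [Worker() for _ in range(4)]
results = [None] * len(workers)

def run(idx):
    # the computation is executed remotely; the local thread only waits for the result
    results[idx] = workers[idx].heavy_compute(10 ** 6)

threads = [threading.Thread(target=run, args=(i,)) for i in range(len(workers))]
for t in threads:
    t.start()
for t in threads:
    t.join()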

Web UI for computation resources

PARL provides a web monitor to watch the status of all resources connected to the cluster. Users can view the cluster status in a web UI, which shows detailed information for each worker (e.g., memory usage) and each submitted task.

Support for various frameworks

PARL's distributed training is compatible with other frameworks such as TensorFlow, PyTorch and MXNet. By adding the @parl.remote_class decorator to their code, users can easily convert it to distributed computation.

Why PARL

High throughput

PARL uses point-to-point connections for network communication within the cluster. Unlike frameworks such as RLlib, which rely on Redis for communication, PARL achieves much higher throughput; the results are shown in figure (b). With the same IMPALA implementation, PARL achieved a 160% increase in data throughput over Ray (RLlib).

Automatic deployment

Unlike other parallel frameworks, which fail to import modules from external files, PARL automatically packages all related files and sends them to the remote machines.
_images/comparison.png

Cluster Setup

Setup Command

This tutorial demonstrates how to set up a cluster.

To start a PARL cluster, we use two xparl commands: xparl start, shown below, and xparl connect, covered in the next subsection.

xparl start --port 6006

This command starts a master node to manage computation resources and adds the local CPUs to the cluster. We use the port 6006 for demonstration, and it can be any available port.

Adding More Resources

Note

If you have only one machine, you can ignore this part.

If you would like to add more CPUs (computation resources) to the cluster, run the following command on the other machines.

xparl connect --address localhost:6006

It starts a worker node that provides the machine's CPUs to the master. A worker uses all of its CPUs by default. If you wish to specify the number of CPUs to be used, run the command with --cpu_num <cpu_num>, as shown below.
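For example, to contribute only 10 CPUs from a worker machine:

xparl connect --address localhost:6006 --cpu_num 10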

Note that the command xparl connect can be run at any time, on any machine, to add more CPUs to the cluster.

Example

Here we give an example demonstrating how to use @parl.remote_class for parallel computation.

import parl

@parl.remote_class
class Actor(object):
    def hello_world(self):
        print("Hello world.")

    def add(self, a, b):
        return a + b

# Connect to the master node.
parl.connect("localhost:6006")

actor = Actor()
actor.hello_world()  # no output in the current terminal, since the computation runs on the cluster
actor.add(1, 2)  # returns 3

Shutdown the Cluster

Run xparl stop on the machine that hosts the master node to stop the cluster processes. Worker nodes on other machines will exit automatically after the master node is stopped.
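That is, on the master machine:

xparl stop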

Further Reading

Now we know how to set up a cluster and use this cluster by simply adding @parl.remote_class.
In the next tutorial, we will show how this decorator helps us implement real multi-threaded computation in Python, breaking the limitation of the Global Interpreter Lock (GIL).

Implemented Algorithms

Policy Gradient

DQN

DDPG

PPO

IMPALA

A3C

parl.Model

parl.Algorithm

parl.Agent

Overview

EvoKit is an evolutionary algorithm library that bundles multiple evolutionary algorithms and is compatible with a variety of inference frameworks, designed for fast online deployment and validation.

_images/DeepES.gif

Features

1. Multiple evolutionary algorithms. Supports Gaussian sampling, CMA, GA and more, with additional algorithms being added continuously.

2. Mainstream optimizers. Supports SGD, Momentum, Adam and other common optimizers to improve convergence efficiency.

3. One-stop deployment. Integrates the online sampling and offline update workflow and provides Bcloud/CMake build options to help you go online quickly.

4. Compatibility with all major deep learning frameworks. EvoKit supports hand-written networks as well as networks built with paddle, lego, Torch and other deep learning frameworks.

5. Synchronous/asynchronous updates. Supports asynchronous updates with multiple sampling models and multiple batches of sampled data, fitting real production scenarios.

Minimal Example

Goal of this tutorial: demonstrate how to solve the classic CartPole problem with the EvoKit library.

This tutorial assumes that the reader has used PaddlePaddle and is familiar with the basic iteration loop of evolutionary algorithms.

About CartPole

CartPole is also known as the inverted pendulum. A pole is attached to a cart and tends to fall over due to gravity. To keep the pole upright, we move the cart left or right, as shown in the figure below. At each time step the model receives a 4-dimensional vector describing the current state of the cart and the pole, and outputs a signal that moves the cart to the left or to the right. For every time step in which the pole stays up, the environment gives a reward of 1; once the pole falls, no further reward is given and the episode ends.

_images/performance.gif

Step 1: Build the inference network

Based on the environment description above, we need a neural network that takes a 4-dimensional vector as input and outputs a 2-dimensional probability distribution over moving left or right. Here we implement the inference network with Paddle and save it to disk.

from paddle import fluid

def net(obs, act_dim):
    hid1 = fluid.layers.fc(obs, size=20)
    prob = fluid.layers.fc(hid1, size=act_dim, act='softmax')
    return prob

if __name__ == '__main__':
    obs_dim = 4
    act_dim = 2
    obs = fluid.layers.data(name="obs", shape=[obs_dim], dtype='float32')
    prob = net(obs, act_dim)

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(fluid.default_startup_program())
    fluid.io.save_inference_model(
        dirname='init_model',
        feeded_var_names=['obs'],
        target_vars=[prob],
        params_filename='params',
        model_filename='model',
        executor=exe)

Step 2: Construct an ESAgent

  • Call load_config to load the configuration file.

  • Call load_inference_model to load the model parameters.

  • Call init_solver to initialize the solver.

The configuration file mainly specifies the type of evolutionary algorithm (e.g. Gaussian or CMA) and the optimizer to use (Adam or SGD).

ESAgent agent = ESAgent();
agent.load_config(config);
agent.load_inference_model(model_dir);
agent.init_solver();

// Example EvoKit configuration
solver {
    type: BASIC_ES
    optimizer { // offline Adam updates
        type: ADAM
        base_lr: 0.05
        adam {
            beta1: 0.9
            beta2: 0.999
            epsilon: 1e-08
        }
    }
    sampling { // online Gaussian sampling
        type: GAUSSIAN_SAMPLING
        gaussian_sampling {
            std: 0.5
            cached: true
            seed: 1024
            cache_size : 100000
        }
    }
}

Step 3: Create an agent for sampling

Three interfaces matter here:

  • Call clone to create an agent used for sampling.

  • Call add_noise to add noise to the agent's parameter space; it returns the unique information identifying that noise, which must be recorded in the log for the offline update.

  • Call predict to run inference.

auto sampling_agent = agent.clone();
auto sampling_info = sampling_agent.add_noise();
sampling_agent.predict(feature);

Step 4: Update the model parameters with the sampled data

The user provides two sets of data:

  • the sampling_info recorded during sampling, which allows the sampling noise to be reproduced offline

  • the evaluation results (rewards) obtained with the perturbed new parameters

agent.update(sampling_infos, rewards);

Main code with comments

The following code demonstrates sampling in parallel with multiple threads to speed up training.

int main(int argc, char* argv[]) {
    std::vector<CartPole> envs;
    // create 10 environments for multi-threaded training
    for (int i = 0; i < ITER; ++i) {
        envs.push_back(CartPole());
    }

    // initialize the ESAgent
    std::string model_dir = "./demo/cartpole/init_model";
    std::string config_path = "./demo/cartpole/config.prototxt";
    std::shared_ptr<ESAgent> agent = std::make_shared<ESAgent>();
    agent->load_config(config_path); // load the configuration

    agent->load_inference_model(model_dir); // load the initial inference model
    agent->init_solver(); // initialize the solver; note it must be called after load_inference_model

    // create 10 agents for parallel sampling
    std::vector<std::shared_ptr<ESAgent>> sampling_agents;
    for (int i = 0; i < ITER; ++i) {
        sampling_agents.push_back(agent->clone());
    }

    std::vector<SamplingInfo> sampling_infos;
    std::vector<float> rewards(ITER, 0.0f);
    sampling_infos.resize(ITER);
    omp_set_num_threads(10);

    // run 100 epochs in total
    for (int epoch = 0; epoch < 100; ++epoch) {
        #pragma omp parallel for schedule(dynamic, 1)
        for (int i = 0; i < ITER; ++i) {
            std::shared_ptr<ESAgent> sampling_agent = sampling_agents[i];
            SamplingInfo sampling_info;
            sampling_agent->add_noise(sampling_info);
            float reward = evaluate(envs[i], sampling_agent);
            // record the sampled sampling_info and the corresponding evaluation reward
            sampling_infos[i] = sampling_info;
            rewards[i] = reward;
        }
        // update the model parameters; after the update they are automatically synchronized to the sampling agents
        agent->update(sampling_infos, rewards);

        int reward = evaluate(envs[0], agent);
        LOG(INFO) << "Epoch:" << epoch << " Reward: " << reward; // log the reward of each epoch
    }
}

How to run the demo

  • Download the code

    Clone the code on icode; our repository path is: baidu/nlp/deep-es TO DO: update the repository path

  • Build the demo

    Build it on bcloud's cloud cluster; the command is: bb

  • Run the demo

    After the build finishes, add the shared library search path:

    export LD_LIBRARY_PATH=./output/so/:$LD_LIBRARY_PATH

    Run the demo: ./output/bin/cartpole/train

Troubleshooting

If you run into any problems, please join Hi group 1692822 (the official PARL support group); the developers will answer usage questions directly.

Example for Online Products

Goal of this tutorial: demonstrate how to iterate the algorithm and update model parameters after going online with the EvoKit library.

In production, user logs usually cannot be obtained online in real time. Instead, click and dwell-time logs are stored, the model is updated offline from this user data, and the updated model is then pushed back online to complete an algorithm update. This tutorial continues with the classic CartPole environment and shows how to iterate the ES algorithm with online sampling and offline updates.

The complete demo code is in the demp/online_example folder. TO DO: fix the folder path

Initialize the solver

Construct the solver, initialize it, and save it to a file. The solver only needs to be initialized once, at the very beginning.

std::shared_ptr<ESAgent> agent = std::make_shared<ESAgent>();
agent->load_config(FLAGS_config_path);
agent->load_inference_model(FLAGS_model_dir);
agent->init_solver();
agent->save_solver(FLAGS_model_dir);

Online sampling

Load the model and the solver, record the sampling_info returned by online sampling together with the evaluated rewards, and write them to a log file in binary format.

std::shared_ptr<ESAgent> agent = std::make_shared<ESAgent>();
agent->load_config(FLAGS_config_path);
agent->load_inference_model(FLAGS_model_dir);
agent->load_solver(FLAGS_model_dir);

#pragma omp parallel for schedule(dynamic, 1)
for (int i = 0; i < ITER; ++i) {
    std::shared_ptr<ESAgent> sampling_agent = sampling_agents[i];
    SamplingInfo sampling_info;
    sampling_agent->add_noise(sampling_info);
    float reward = evaluate(envs[i], sampling_agent);
    sampling_infos[i] = sampling_info;
    rewards[i] = reward;
}

// save the sampling information and rewards to the log in binary format
std::ofstream log_stream(FLAGS_log_path, std::ios::binary);
for (int i = 0; i < ITER; ++i) {
    std::string data;
    sampling_infos[i].SerializeToString(&data);
    int size = data.size();
    log_stream.write((char*) &rewards[i], sizeof(float));
    log_stream.write((char*) &size, sizeof(int));
    log_stream.write(data.c_str(), size);
}
log_stream.close();

Offline update

After loading the previously recorded log, call the update function to update the parameters, then save the updated parameters locally with save_inference_model and save_solver and push them online.

std::shared_ptr<ESAgent> agent = std::make_shared<ESAgent>();
agent->load_config(FLAGS_config_path);
agent->load_inference_model(FLAGS_model_dir);
agent->load_solver(FLAGS_model_dir);

// load training data
std::vector<SamplingInfo> sampling_infos;
std::vector<float> rewards(ITER, 0.0f);
sampling_infos.resize(ITER);
std::ifstream log_stream(FLAGS_log_path, std::ios::binary);
CHECK(log_stream.good()) << "[EvoKit] cannot open log: " << FLAGS_log_path;
char buffer[1000];
for (int i = 0; i < ITER; ++i) {
    int size;
    log_stream.read((char*) &rewards[i], sizeof(float));
    log_stream.read((char*) &size, sizeof(int));
    log_stream.read(buffer, size);
    // construct the string with an explicit length: serialized protobuf data may contain '\0'
    std::string data(buffer, size);
    sampling_infos[i].ParseFromString(data);
}

// update model and save parameter
agent->update(sampling_infos, rewards);
agent->save_inference_model(FLAGS_updated_model_dir);
agent->save_solver(FLAGS_updated_model_dir);

Main script

Compile the code above into three separate executables:

  • initialize the solver: init_solver

  • online sampling: online_sampling

  • offline update: offline_update

#------------------------init solver------------------------
./init_solver \
    --model_dir="./model_warehouse/model_dir_0" \
    --config_path="config.prototxt"


for ((epoch=0;epoch<200;++epoch));do
#------------------------online sampling------------------------
    ./online_sampling \
        --log_path="./sampling_log" \
        --model_dir="./model_warehouse/model_dir_$epoch" \
        --config_path="./config.prototxt"

#------------------------offline update------------------------
    next_epoch=$((epoch+1))
    ./offline_update \
        --log_path='./sampling_log' \
        --model_dir="./model_warehouse/model_dir_$epoch" \
        --updated_model_dir="./model_warehouse/model_dir_${next_epoch}" \
        --config_path="./config.prototxt"
done