Extend to other RL(HF) algorithms
===================================

Last updated: 02/25/2025.

We have already implemented the complete training pipeline for the PPO
algorithm. To extend verl to other algorithms, we analyze the high-level
principles of using verl and provide a tutorial that implements the DPO
algorithm. Users can follow a similar paradigm to extend verl to other RL
algorithms.

.. note::

   **Key ideas**: A single process drives multi-process computation and data communication.

Overall Approach
----------------

Step 1: Consider what multi-machine multi-GPU computations are needed
for each model, such as ``generate_sequences``, ``compute_log_prob`` and
``update_policy`` in the actor_rollout model. Implement distributed
single-process-multiple-data (SPMD) computation and encapsulate it
into APIs.

Step 2: Based on different distributed scenarios, including FSDP and 3D
parallelism in Megatron-LM, implement single-process control of data
interaction among multi-process computations.

Step 3: Utilize the encapsulated APIs to implement the control flow.

Example: Online DPO
-------------------

We use verl to implement a simple online DPO algorithm. The algorithm
flow of Online DPO is as follows:

1. There is a prompt (rollout) generator that has the same weights as
   the actor model. After a batch of prompts is fed into the generator,
   it generates N responses for each prompt.
2. Send all the prompts + responses to a verifier for scoring, which can
   be a reward model or a rule-based function. Then sort them in pairs to
   form a training batch.
3. Use this training batch to train the actor model using DPO. During
   the process, a reference policy is needed.

Step 1: What are the multi-machine multi-GPU computations
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

**Sample Generator**

Implementation details:

.. code:: python

   from verl.single_controller.base import Worker
   from verl.single_controller.ray import RayWorkerGroup, RayClassWithInitArgs, RayResourcePool
   import ray

   @ray.remote
   class SampleGenerator(Worker):
       def __init__(self, config):
           super().__init__()
           self.config = config

       def generate_sequences(self, data):
           pass

Here, ``SampleGenerator`` can be viewed as a group of processes launched
by ``torchrun``, with each process running the same code (SPMD).
``SampleGenerator`` needs to implement a ``generate_sequences`` API for
the control flow to call. The implementation inside can use any
inference engine, including vllm, sglang and huggingface. Users can
largely reuse the code in
verl/verl/workers/rollout/vllm_rollout/vllm_rollout.py, so we won't go
into details here.

**ReferencePolicy inference**

API: compute reference log probability

.. code:: python

   from verl.single_controller.base import Worker
   import ray

   @ray.remote
   class ReferencePolicy(Worker):
       def __init__(self):
           super().__init__()
           self.model = Model()  # Model is a placeholder for the user's model class

       def infer(self, data):
           return self.model(data)

**Actor update**

API: Update actor model parameters

.. code:: python

   from verl.single_controller.base import Worker
   from torch import optim
   from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
   import ray

   @ray.remote
   class DPOActor(Worker):
       def __init__(self):
           super().__init__()
           self.model = Model()  # Model is a placeholder for the user's model class
           self.model = FSDP(self.model)  # or other distributed strategy
           self.optimizer = optim.Adam(self.model.parameters(), lr=1e-3)
           self.loss_fn = xxx  # placeholder for the DPO loss function

       def update(self, data):
           self.optimizer.zero_grad()
           logits = self.model(data)
           loss = self.loss_fn(logits)
           loss.backward()
           self.optimizer.step()

**Notes: How to distinguish between control processes and distributed computation processes**
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

- Control processes are generally functions directly decorated with
  ``@ray.remote``

- Computation processes are all wrapped into a ``RayWorkerGroup``.

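As a hedged illustration of what the ``loss_fn`` placeholder in ``DPOActor`` could compute, the sketch below evaluates the standard DPO objective for a single (chosen, rejected) pair in plain Python. The function name ``dpo_loss``, its arguments, and the default ``beta=0.1`` are assumptions made for this example and are not part of verl; a real actor would compute the same quantity on batched tensors.

```python
import math


def dpo_loss(policy_chosen_logp, policy_rejected_logp,
             ref_chosen_logp, ref_rejected_logp, beta=0.1):
    """DPO loss for one (chosen, rejected) pair of summed sequence log-probs.

    Hypothetical helper for illustration only; not a verl API.
    """
    # Implicit reward margin: how much more the policy prefers the chosen
    # response over the rejected one, relative to the reference policy.
    margin = (policy_chosen_logp - ref_chosen_logp) - (
        policy_rejected_logp - ref_rejected_logp)
    z = beta * margin
    # -log(sigmoid(z)) equals softplus(-z); this form avoids overflow.
    return max(-z, 0.0) + math.log1p(math.exp(-abs(z)))


# Usage: the loss shrinks as the policy favors the chosen response more
# strongly than the reference policy does.
loss_neutral = dpo_loss(-1.0, -2.0, -1.0, -2.0)   # policy equals reference
loss_better = dpo_loss(0.0, -5.0, 0.0, 0.0)       # policy prefers chosen
loss_worse = dpo_loss(-5.0, 0.0, 0.0, 0.0)        # policy prefers rejected
```

The reference log-probabilities enter only through the margin, which is why the control flow needs a separate ``ReferencePolicy`` worker in addition to the actor.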
Users can reuse most of the distributed computation logic implemented
for the PPO algorithm, including the FSDP and Megatron-LM backends, in
verl/verl/trainer/ppo.

Step 2: Based on different distributed scenarios, implement single-process control of multi-process data interaction
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

**The core problem to solve here is how a single process sends data to
multiple processes, drives multi-process computation, and obtains the
results of that computation.**

First, we initialize the multi-process ``WorkerGroup`` in the control
process.

.. code:: python

   @ray.remote(num_cpus=1)
   def main_task(config):
       # construct SampleGenerator
       resource_pool = RayResourcePool(process_on_nodes=[8] * 2)  # 16 GPUs
       ray_cls = RayClassWithInitArgs(SampleGenerator, config=config)
       # put SampleGenerator onto resource pool
       worker_group = RayWorkerGroup(resource_pool, ray_cls)

       # construct reference policy

As we can see, in the control process, multiple processes are wrapped
into a ``RayWorkerGroup``. Inside this ``WorkerGroup``, there is a
``self._workers`` member, where each worker is a RayActor
(https://docs.ray.io/en/latest/ray-core/actors.html) of SampleGenerator.
ray_trainer.md also provides an implementation of
``MegatronRayWorkerGroup``.

Assuming the model is distributed using FSDP and there is a batch of
data on the control process, the underlying calling process for data
parallelism is:

.. code:: python

   import ray
   import torch

   data = xxx  # placeholder: a batch of data on the control process
   data_list = data.chunk(dp_size)

   output = []
   for i, d in enumerate(data_list):
       # worker_group._workers[i] is a SampleGenerator
       output.append(worker_group._workers[i].generate_sequences.remote(d))

   output = ray.get(output)
   output = torch.cat(output)

A single process calling multiple processes involves the following 3
steps:

1. Split the data into DP parts on the control process.
2. Send the data to the remote workers, call the remote computation
   through RPC, and utilize multi-process computation.
3. Obtain the computation results of each worker on the control process
   and merge them.

Repeating these 3 steps at every call site in the controller process
greatly hurts code readability.

**In verl, we have abstracted and encapsulated these 3 steps, so that
the worker's method + dispatch + collect can be registered into the
worker_group**

.. code:: python

   import torch
   from verl.single_controller.base.decorator import register

   def dispatch_data(worker_group, data):
       return data.chunk(worker_group.world_size)

   def collect_data(worker_group, data):
       return torch.cat(data)

   dispatch_mode = {
       'dispatch_fn': dispatch_data,
       'collect_fn': collect_data
   }

   @register(dispatch_mode=dispatch_mode)
   def generate_sequences(self, data):
       pass

In this way, we can directly call the method inside the worker through
the ``worker_group`` on the control (driver) process (which is a single
process):

.. code:: python

   output = worker_group.generate_sequences(data)

This single line includes data splitting, data distribution and
computation, and data collection.

Furthermore, the model parallelism sizes of each model (dp, tp, pp) are
usually fixed. So for these common distributed scenarios, we have
pre-implemented specific dispatch and collect methods in
``decorator.py``, which can be directly used to wrap the computations.

.. code:: python

   from verl.protocol import DataProto
   from verl.single_controller.base.decorator import register, Dispatch

   @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
   def generate_sequences(self, data: DataProto) -> DataProto:
       pass

This requires the data interface to be ``DataProto``. The definition of
``DataProto`` is in ``protocol.py``.

Step 3: Main training loop
~~~~~~~~~~~~~~~~~~~~~~~~~~

With the APIs above, we can implement the algorithm's control flow. It
is recommended that ``main_task`` also be a ray remote process.

.. code:: python

   @ray.remote(num_cpus=1)
   def main_task(config):
       # construct SampleGenerator
       resource_pool = RayResourcePool(process_on_nodes=[8] * 2)  # 16 GPUs
       ray_cls = RayClassWithInitArgs(SampleGenerator, config=config)
       # put SampleGenerator onto resource pool
       sample_gen = RayWorkerGroup(resource_pool, ray_cls)

       # construct reference policy
       ray_cls = RayClassWithInitArgs(ReferencePolicy)
       ref_policy = RayWorkerGroup(resource_pool, ray_cls)

       # construct actor
       ray_cls = RayClassWithInitArgs(DPOActor)
       dpo_policy = RayWorkerGroup(resource_pool, ray_cls)

       dataloader = DataLoader()  # placeholder for the prompt dataloader

       for data in dataloader:
           # generate data
           data = sample_gen.generate_sequences(data)
           # generate scores for each data
           data = generate_scores(data)
           # generate pairwise data using scores
           data = generate_pairwise_data(data)
           # generate ref_log_prob
           data.batch['ref_log_prob'] = ref_policy.infer(data)
           # update using dpo
           dpo_policy.update(data)
           # logging

Here, different ``WorkerGroups`` can be placed in the same resource pool
or in different resource pools using ``create_colocated_worker_cls``,
similar to ``ray_trainer.py``.
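The training loop above leaves ``generate_pairwise_data`` undefined. The sketch below shows one possible way to turn N scored responses per prompt into (chosen, rejected) pairs, using plain Python lists rather than ``DataProto``. The function name ``make_preference_pairs``, its signature, and the best-versus-worst pairing rule are assumptions for illustration only; other pairing strategies are equally valid.

```python
def make_preference_pairs(prompts, responses, scores):
    """Pair the highest- and lowest-scored response for each prompt.

    Hypothetical helper for illustration only; not a verl API.
    prompts:   list of prompt strings
    responses: list (per prompt) of N response strings
    scores:    list (per prompt) of N verifier scores
    """
    pairs = []
    for prompt, resp_list, score_list in zip(prompts, responses, scores):
        # Sort by score only, so responses themselves are never compared.
        ranked = sorted(zip(score_list, resp_list), key=lambda item: item[0])
        low_score, rejected = ranked[0]
        high_score, chosen = ranked[-1]
        if high_score == low_score:
            # All responses scored the same: no preference signal, skip.
            continue
        pairs.append({"prompt": prompt, "chosen": chosen, "rejected": rejected})
    return pairs


# Usage: the second prompt is dropped because both responses tie.
pairs = make_preference_pairs(
    prompts=["p1", "p2"],
    responses=[["a", "b", "c"], ["x", "y"]],
    scores=[[0.2, 0.9, 0.5], [1.0, 1.0]],
)
```

Skipping tied prompts keeps pairs with no preference signal out of the DPO batch, at the cost of a variable batch size.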