[Feat]: Support for CPU and Non-Tensor Data Storage of Yuanrong by dpj135 · Pull Request #107 · TransferQueue/TransferQueue
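The example script below exercises the new Yuanrong storage backend end to end: prompts are staged in TransferQueue, sequences are generated either synchronously through ActorRolloutRefWorker or through an async rollout path (RolloutManager -> AsyncRolloutWorker -> AsyncvLLMServer), and the results written back are TensorDicts that mix plain tensors with non-tensor payloads (NonTensorData) and nested (ragged) tensors, exercising the CPU and non-tensor storage paths this PR adds.
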
# Copyright 2025 The TransferQueue Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import logging
import math
import os
import sys
import time
from pathlib import Path

try:
    import torch_npu

    print("NPU device count:", torch_npu.npu.device_count())
except Exception as e:
    print("NPU not available:", e)

import ray
import torch
from omegaconf import OmegaConf
from tensordict import NonTensorData, TensorDict

parent_dir = Path(__file__).resolve().parent.parent.parent
sys.path.append(str(parent_dir))

from transfer_queue import (  # noqa: E402
    AsyncTransferQueueClient,
    BatchMeta,
    SimpleStorageUnit,
    TransferQueueController,
    process_zmq_server_info,
)
from transfer_queue.utils.utils import get_placement_group  # noqa: E402

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

os.environ["RAY_DEDUP_LOGS"] = "0"
os.environ["RAY_DEBUG"] = "1"

ray.init()


def compute_old_log_prob(data1, data2):
    time.sleep(3)
    return data1


def generate_sequences(data):
    time.sleep(3)
    return data


class ActorRolloutRefWorker:
    def actor_rollout_wg_generate_sequences(self, data_meta, data_system_client):
        # 1. Pull real data from the storage plane through the client, based on data_meta
        data = asyncio.run(data_system_client.async_get_data(data_meta))
        logger.info(f"demo get data->generate_sequences {data}")

        output = generate_sequences(data["input_ids"])
        output = TensorDict(
            {
                "generate_sequences_ids": output,
                "non_tensor_data": torch.stack([NonTensorData("test_str") for _ in range(output.size(0))]),
                "nested_tensor": torch.nested.as_nested_tensor([torch.randn(1, 2) for _ in range(output.size(0))]),
            },
            batch_size=output.size(0),
        )
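        # NOTE (non-tensor storage): torch.stack over NonTensorData entries builds a
        # NonTensorStack, which is how a per-sample Python payload (the string
        # "test_str" here) rides along inside the TensorDict; the nested tensor field
        # exercises ragged shapes. Both are round-tripped through the storage plane
        # by the async_put/async_get_data calls in this demo.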
        # 2. Write results back to the storage plane, based on data_meta
        asyncio.run(data_system_client.async_put(data=output, metadata=data_meta))
        data_meta.add_fields(output)
        logger.info("demo put data to storages done")
        return data_meta

    def actor_rollout_wg_compute_old_log_prob(self, data_meta, data_system_client):
        # 1. Pull real data from the storage plane through the client, based on data_meta
        data = asyncio.run(data_system_client.async_get_data(data_meta))
        logger.info(f"demo get data->old_log_prob {data}")

        output = compute_old_log_prob(data["input_ids"], data["generate_sequences_ids"])
        output = TensorDict({"old_log_prob": output}, batch_size=output.size(0))

        # 2. Write results back to the storage plane, based on data_meta
        asyncio.run(data_system_client.async_put(data=output, metadata=data_meta))
        data_meta.add_fields(output)
        logger.info("demo put data to storages done")
        return data_meta


@ray.remote(resources={"NPU": 1})
class AsyncvLLMServer:
    def __init__(self, config, data_system_controller_info):
        self.config = config
        self.data_system_client = AsyncTransferQueueClient(
            client_id="AsyncvLLMServer",
            controller_info=data_system_controller_info,
        )
        self.data_system_client.initialize_storage_manager(manager_type="YuanrongStorageManager", config=self.config)

    async def generate(self, data_meta):
        data = await self.data_system_client.async_get_data(data_meta)
        logger.info(f"demo get data->generate_sequences {data}")

        data = data["input_ids"]
        data += 1
        await asyncio.sleep(3)

        output = TensorDict(
            {
                "generate_sequences_ids": data,
                "non_tensor_data": torch.stack([NonTensorData("test_str") for _ in range(data.size(0))]),
                "nested_tensor": torch.nested.as_nested_tensor([torch.randn(1, 2) for _ in range(data.size(0))]),
            },
            batch_size=data.size(0),
        )

        await self.data_system_client.async_put(data=output, metadata=data_meta)
        logger.info("demo Async Server put data to storages done")
        return data_meta


@ray.remote(num_cpus=1, resources={"NPU": 1})
class AsyncRolloutWorker:
    def __init__(
        self,
        config,
        data_system_controller_info,
    ):
        self.async_vllm_server = AsyncvLLMServer.remote(
            config,
            data_system_controller_info,
        )

    async def generate_sequences(self, data_meta_chunk):
        tasks = []
        for i in range(data_meta_chunk.size):
            # asyncio.create_task cannot take a Ray actor method call directly;
            # doing so raises: "a coroutine was expected, got ObjectRef(xxx)"
            tasks.append(asyncio.create_task(self.generate(data_meta_chunk[i])))
        data_metas = await asyncio.gather(*tasks)
        return BatchMeta.concat(data_metas)

    async def generate(self, data_meta):
        data_meta_new = await self.async_vllm_server.generate.remote(data_meta)
        return data_meta_new
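
# NOTE: the thin `generate` wrapper above exists because asyncio.create_task
# expects a coroutine, while actor.method.remote(...) returns an ObjectRef.
# ObjectRefs are awaitable, so a local coroutine bridges the two. A minimal
# sketch of the pattern, with a hypothetical actor handle `server`:
#
#     async def call(server, x):
#         return await server.generate.remote(x)  # await the ObjectRef here
#
#     task = asyncio.create_task(call(server, x))  # OK: a real coroutine
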
class RolloutManager:
    def __init__(self, config, data_system_storage_unit_infos, data_system_controller_info):
        self.config = config
        self.data_system_client = AsyncTransferQueueClient(
            client_id="RolloutManager",
            controller_info=data_system_controller_info,
        )
        self.data_system_client.initialize_storage_manager(manager_type="YuanrongStorageManager", config=self.config)

        self.async_rollout_workers = []
        num_workers = self.config.rollout_agent_num_workers
        for _ in range(num_workers):
            self.async_rollout_workers.append(AsyncRolloutWorker.remote(config, data_system_controller_info))

    def generate_sequences(self, data_meta):
        data_meta_chunks = data_meta.chunk(len(self.async_rollout_workers))
        data_metas = ray.get(
            [
                worker.generate_sequences.remote(data_meta_chunk)
                for worker, data_meta_chunk in zip(self.async_rollout_workers, data_meta_chunks, strict=True)
            ]
        )
        batch_meta = BatchMeta.concat(data_metas)
        logger.info(f"batch_meta: {batch_meta}")
        return batch_meta


class Trainer:
    def __init__(self, config):
        self.config = config
        self.data_system_client = self._initialize_data_system()
        self.actor_rollout_wg = ActorRolloutRefWorker()
        self.async_rollout_manager = RolloutManager(
            self.config,
            self.data_system_storage_unit_infos,
            self.data_system_controller_info,
        )

    def _initialize_data_system(self):
        # 1. Initialize TransferQueueStorage
        total_storage_size = self.config.global_batch_size * self.config.num_global_batch * self.config.num_n_samples
        self.data_system_storage_units = {}
        storage_placement_group = get_placement_group(self.config.num_data_storage_units, num_cpus_per_actor=1)
        for storage_unit_rank in range(self.config.num_data_storage_units):
            storage_node = SimpleStorageUnit.options(
                placement_group=storage_placement_group, placement_group_bundle_index=storage_unit_rank
            ).remote(storage_unit_size=math.ceil(total_storage_size / self.config.num_data_storage_units))
            self.data_system_storage_units[storage_unit_rank] = storage_node
            logger.info(f"SimpleStorageUnit #{storage_unit_rank} has been created.")

        # 2. Initialize TransferQueueController (single controller only)
        # Sampler usage instructions:
        # For GRPO grouped sampling, you can initialize the controller with GRPOGroupNSampler:
        #   Option 1: pass the sampler class (it will be instantiated automatically)
        #     self.data_system_controller = TransferQueueController.remote(sampler=GRPOGroupNSampler)
        #   Option 2: pass a sampler instance (if you need custom configuration)
        #     grpo_sampler = GRPOGroupNSampler()
        #     self.data_system_controller = TransferQueueController.remote(sampler=grpo_sampler)
        # Then pass sampling_config in get_meta calls:
        #   sampling_config={"n_samples_per_prompt": 4}
        self.data_system_controller = TransferQueueController.remote()
        logger.info("TransferQueueController has been created.")

        # 3. Prepare necessary information
        self.data_system_controller_info = process_zmq_server_info(self.data_system_controller)
        self.data_system_storage_unit_infos = process_zmq_server_info(self.data_system_storage_units)
        # Note: we need a fresh DictConfig created with allow_objects=True to keep the
        # ZMQServerInfo instances intact; otherwise they would be flattened into plain dicts.
        tq_config = OmegaConf.create({}, flags={"allow_objects": True})
        tq_config.controller_info = self.data_system_controller_info
        tq_config.storage_unit_infos = self.data_system_storage_unit_infos
        self.config = OmegaConf.merge(tq_config, self.config)

        # 4. Create client
        self.data_system_client = AsyncTransferQueueClient(
            client_id="Trainer",
            controller_info=self.data_system_controller_info,
        )
        self.data_system_client.initialize_storage_manager(manager_type="YuanrongStorageManager", config=self.config)
        # Note: the client contains ZMQ objects, so the same client instance cannot be
        # transmitted to multiple places; doing so causes serialization errors in Ray.
        # Workaround: if you need a client in multiple Ray actors or processes, create a
        # separate AsyncTransferQueueClient instance for each one instead of sharing it.
        return self.data_system_client
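    # Control flow of fit() below: async_put stages a prompt batch under a
    # partition_id, async_get_meta obtains a BatchMeta handle from the controller
    # for each task, workers move real data with async_get_data/async_put against
    # that handle, and async_clear releases the partition at the end of the step.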
") time.sleep(5) batch_meta = asyncio.run( self.data_system_client.async_get_meta( data_fields=["input_ids", "attention_mask"], batch_size=self.config.global_batch_size * self.config.num_n_samples, partition_id=f"train_{step}", task_name="generate_sequences", ) ) logger.info(f"demo get meta {batch_meta}") # Simulate calling the generate sequences task of the worker group if not self.config.async_rollout_mode: batch_meta = self.actor_rollout_wg.actor_rollout_wg_generate_sequences( batch_meta, self.data_system_client ) else: batch_meta = self.async_rollout_manager.generate_sequences(batch_meta) log_prob_meta = asyncio.run( self.data_system_client.async_get_meta( data_fields=["input_ids", "attention_mask", "generate_sequences_ids"], batch_size=self.config.global_batch_size * self.config.num_n_samples, partition_id=f"train_{step}", task_name="compute_old_log_prob", ) ) logger.info(f"demo get log prob meta: {log_prob_meta}") # Simulate calling the compute old log prob task of the worker group old_log_prob_meta = self.actor_rollout_wg.actor_rollout_wg_compute_old_log_prob( log_prob_meta, self.data_system_client ) batch_meta = batch_meta.union(old_log_prob_meta) # Client notifies controller to clear data status, controller returns metadata; # Client then notifies the storage plane to clear based on metadata asyncio.run(self.data_system_client.async_clear(partition_id=f"train_{step}")) logger.info("clear ok! ") logger.info("demo done!") # Cleanup resources self.data_system_client.close() return batch_meta if __name__ == "__main__": # NOTE: you may choose to set async_rollout_mode=True to test the async rollout mode that mimics # AgentLoopManager in verl config_str = """ global_batch_size: 8 num_global_batch: 1 num_data_storage_units: 2 async_rollout_mode: True rollout_agent_num_workers: 2 num_n_samples: 2 host: '127.0.0.1' port: 31511 client_name: 'YuanrongStorageClient' """ dict_conf = OmegaConf.create(config_str) trainer = Trainer(dict_conf) trainer.fit() ray.shutdown() print('demo finished!')