feat: data parallel inference examples by bowang007 · Pull Request #2805 · pytorch/TensorRT
--- /home/runner/work/TensorRT/TensorRT/examples/distributed_inference/data_parallel_gpt2.py	2024-05-02 00:29:27.054073+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/distributed_inference/data_parallel_gpt2.py	2024-05-02 00:31:18.785078+00:00
@@ -13,12 +13,26 @@
 distributed_state = PartialState()
 
 model = GPT2LMHeadModel.from_pretrained("gpt2").eval().to(distributed_state.device)
-model.forward = torch.compile(model.forward, backend="torch_tensorrt", options={"truncate_long_and_double": True, "enabled_precisions": {torch.float16}, "debug": True}, dynamic=False,)
+model.forward = torch.compile(
+    model.forward,
+    backend="torch_tensorrt",
+    options={
+        "truncate_long_and_double": True,
+        "enabled_precisions": {torch.float16},
+        "debug": True,
+    },
+    dynamic=False,
+)
 
 
 with distributed_state.split_between_processes([input_id1, input_id2]) as prompt:
     cur_input = torch.clone(prompt[0]).to(distributed_state.device)
-    gen_tokens = model.generate(cur_input, do_sample=True, temperature=0.9, max_length=100,)
+    gen_tokens = model.generate(
+        cur_input,
+        do_sample=True,
+        temperature=0.9,
+        max_length=100,
+    )
 
 gen_text = tokenizer.batch_decode(gen_tokens)[0]
 