feat: data parallel inference examples by bowang007 · Pull Request #2805 · pytorch/TensorRT

--- /home/runner/work/TensorRT/TensorRT/examples/distributed_inference/data_parallel_gpt2.py	2024-05-02 00:29:27.054073+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/distributed_inference/data_parallel_gpt2.py	2024-05-02 00:31:18.785078+00:00
@@ -13,12 +13,26 @@

distributed_state = PartialState()

model = GPT2LMHeadModel.from_pretrained("gpt2").eval().to(distributed_state.device)

-model.forward = torch.compile(model.forward, backend="torch_tensorrt", options={"truncate_long_and_double": True, "enabled_precisions": {torch.float16}, "debug": True}, dynamic=False,)
+model.forward = torch.compile(
+    model.forward,
+    backend="torch_tensorrt",
+    options={
+        "truncate_long_and_double": True,
+        "enabled_precisions": {torch.float16},
+        "debug": True,
+    },
+    dynamic=False,
+)

with distributed_state.split_between_processes([input_id1, input_id2]) as prompt:
    cur_input = torch.clone(prompt[0]).to(distributed_state.device)

-    gen_tokens = model.generate(cur_input, do_sample=True, temperature=0.9, max_length=100,)
+    gen_tokens = model.generate(
+        cur_input,
+        do_sample=True,
+        temperature=0.9,
+        max_length=100,
+    )
    gen_text = tokenizer.batch_decode(gen_tokens)[0]