Fix grid_sample by HolyWu · Pull Request #3340 · pytorch/TensorRT
This PR fixes two issues in grid_sample.
import os import torch import torch.nn.functional as F import torch_tensorrt os.environ["CI_BUILD"] = "1" class MyModule(torch.nn.Module): def __init__(self) -> None: super().__init__() def forward(self, x: torch.Tensor, grid: torch.Tensor) -> torch.Tensor: return F.grid_sample(x, grid, mode="bilinear", align_corners=False) with torch.inference_mode(): model = MyModule().eval().cuda() inputs = [torch.randn(1, 3, 224, 224, device="cuda"), torch.randn(1, 224, 224, 2, device="cuda")] trt_model = torch_tensorrt.compile(model, "dynamo", inputs, debug=True, min_block_size=1) torch.testing.assert_close(trt_model(*inputs), model(*inputs), rtol=5e-3, atol=5e-3) print("assert_close passed")
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_detach:Removed 0 detach nodes: graph(): %x : [num_users=1] = placeholder[target=x] %grid : [num_users=1] = placeholder[target=grid] %grid_sampler : [num_users=1] = call_function[target=torch.ops.aten.grid_sampler.default](args = (%x, %grid, 0, 0, False), kwargs = {}) return (grid_sampler,) DEBUG:torch_tensorrt.dynamo._compiler:Input graph: graph(): %x : [num_users=1] = placeholder[target=x] %grid : [num_users=1] = placeholder[target=grid] %grid_sampler_2d : [num_users=1] = call_function[target=torch.ops.aten.grid_sampler_2d.default](args = (%x, %grid, 0, 0, False), kwargs = {}) return (grid_sampler_2d,) DEBUG:torch_tensorrt.dynamo.lowering.passes.constant_folding:Graph after constant folding: graph(): %x : [num_users=1] = placeholder[target=x] %grid : [num_users=1] = placeholder[target=grid] %grid_sampler_2d : [num_users=1] = call_function[target=torch.ops.aten.grid_sampler_2d.default](args = (%x, %grid, 0, 0, False), kwargs = {}) return (grid_sampler_2d,) DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_assert_scalar:Removed 0 assert_scalar nodes: graph(): %x : [num_users=1] = placeholder[target=x] %grid : [num_users=1] = placeholder[target=grid] %grid_sampler_2d : [num_users=1] = call_function[target=torch.ops.aten.grid_sampler_2d.default](args = (%x, %grid, 0, 0, False), kwargs = {}) return (grid_sampler_2d,) DEBUG:torch_tensorrt.dynamo.lowering.passes.accumulate_fp32_matmul:Skipping FP32 accumulation for matmul layers as use_fp32_acc is not enabled in the compilation settings DEBUG:torch_tensorrt.dynamo._compiler:Lowered Input graph: graph(): %x : [num_users=1] = placeholder[target=x] %grid : [num_users=1] = placeholder[target=grid] %grid_sampler_2d : [num_users=1] = call_function[target=torch.ops.aten.grid_sampler_2d.default](args = (%x, %grid, 0, 0, False), kwargs = {}) return (grid_sampler_2d,) DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Converter options for aten.grid_sampler_2d.default: 1 DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Selecting converter option 0 for converting aten.grid_sampler_2d.default DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner: Supported Nodes: - torch.ops.aten.grid_sampler_2d.default + Operator Count: 1 DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner: All Nodes Supported DEBUG:torch_tensorrt.dynamo._compiler:Detected support for 1 operators out of 1 in subgraph. INFO:torch_tensorrt.dynamo._compiler:Partitioning the graph via the fast partitioner DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Converter options for aten.grid_sampler_2d.default: 1 DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Selecting converter option 0 for converting aten.grid_sampler_2d.default DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner: Number of TensorRT-Accelerated Engines Generated: 1 DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner: Supported Nodes: - torch.ops.aten.grid_sampler_2d.default + Operator Count: 1 DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner: All Nodes Supported DEBUG:torch_tensorrt.dynamo._compiler:Updated metadata for node: _run_on_acc_0 with its corresponding submodule outputs DEBUG:torch_tensorrt.dynamo._compiler:Converting submodule: _run_on_acc_0 Input shapes: [(1, 3, 224, 224), (1, 224, 224, 2)] graph(): %x : [num_users=1] = placeholder[target=x] %grid : [num_users=1] = placeholder[target=grid] %grid_sampler_2d : [num_users=1] = call_function[target=torch.ops.aten.grid_sampler_2d.default](args = (%x, %grid, 0, 0, False), kwargs = {}) return grid_sampler_2d DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Converter options for aten.grid_sampler_2d.default: 1 DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Selecting converter option 0 for converting aten.grid_sampler_2d.default DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node x (kind: x, args: ()) DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Adding input to in-progress INetwork: x [shape=[1, 3, 224, 224], dtype=DataType.FLOAT] INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node x [x] (Inputs: () | Outputs: (x: (1, 3, 224, 224)@torch.float32)) DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node grid (kind: grid, args: ()) DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Adding input to in-progress INetwork: grid [shape=[1, 224, 224, 2], dtype=DataType.FLOAT] INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node grid [grid] (Inputs: () | Outputs: (grid: (1, 224, 224, 2)@torch.float32)) DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node /grid_sampler_2d (kind: aten.grid_sampler_2d.default, args: ('x <Node>', 'grid <Node>', '0 <int>', '0 <int>', 'False <bool>')) DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Converter options for aten.grid_sampler_2d.default: 1 DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Selecting converter option 0 for converting aten.grid_sampler_2d.default INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node /grid_sampler_2d [aten.grid_sampler_2d.default] (Inputs: (x: (1, 3, 224, 224)@torch.float32, grid: (1, 224, 224, 2)@torch.float32, 0, 0, False) | Outputs: (grid_sampler_2d: (1, 3, 224, 224)@torch.float32)) DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node output (kind: output, args: ('grid_sampler_2d <Node>',)) DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Marking output output0 [shape=(1, 3, 224, 224), dtype=DataType.FLOAT] INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node output [output] (Inputs: (grid_sampler_2d: (1, 3, 224, 224)@torch.float32) | Outputs: (output: )) INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT INetwork construction elapsed time: 0:00:00.001957 INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Not found cached TRT engines. Start building engine. INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Build TRT engine elapsed time: 0:00:00.120683 INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT Engine uses: 12180 bytes of Memory DEBUG: [Torch-TensorRT] - Deserializing Device Info: 0%8%9%0%NVIDIA GeForce RTX 4060 Ti DEBUG: [Torch-TensorRT] - Deserialized Device Info: Device(ID: 0, Name: NVIDIA GeForce RTX 4060 Ti, SM Capability: 8.9, Type: GPU) DEBUG: [Torch-TensorRT] - Target Device: Device(ID: 0, Name: NVIDIA GeForce RTX 4060 Ti, SM Capability: 8.9, Type: GPU) DEBUG: [Torch-TensorRT] - Setting Device(ID: 0, Name: NVIDIA GeForce RTX 4060 Ti, SM Capability: 8.9, Type: GPU) as active device INFO: [Torch-TensorRT] - Loaded engine size: 0 MiB DEBUG: [Torch-TensorRT] - Deserialization required 3953 microseconds. DEBUG: [Torch-TensorRT] - Total per-runner device persistent memory is 0 DEBUG: [Torch-TensorRT] - Total per-runner host persistent memory is 80 DEBUG: [Torch-TensorRT] - Allocated device scratch memory of size 0 DEBUG: [Torch-TensorRT] - - Runner scratch: 0 bytes INFO: [Torch-TensorRT] - [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +0, now: CPU 0, GPU 0 (MiB) DEBUG: [Torch-TensorRT] - CUDA lazy loading is enabled. DEBUG: [Torch-TensorRT] - Input binding name: x has TensorRT binding index: 0, Torch binding index: 0 DEBUG: [Torch-TensorRT] - Input binding name: grid has TensorRT binding index: 1, Torch binding index: 1 DEBUG: [Torch-TensorRT] - Output binding name: output0 has TensorRT binding index: 2, Torch binding index: 2 DEBUG: [Torch-TensorRT] - Torch-TensorRT TensorRT Engine: Name: _run_on_acc_0_engine Inputs: [ id: 0 name: x shape: [1, 3, 224, 224] dtype: Float id: 1 name: grid shape: [1, 224, 224, 2] dtype: Float ] Outputs: [ id: 0 name: output0 shape: [1, 3, 224, 224] dtype: Float ] Device: Device(ID: 0, Name: NVIDIA GeForce RTX 4060 Ti, SM Capability: 8.9, Type: GPU) Hardware Compatibility: Disabled Target Platform: windows_x86_64 DEBUG:torch_tensorrt.dynamo._DryRunTracker: ++++++++++++++++++++++++++++++++++++++++++++++++++ Dry-Run Results for Graph ++++++++++++++++++++++++++++++++++++++++++++++++++ The graph consists of 1 Total Operators, of which 1 operators are supported, 100.0% coverage Compiled with: CompilationSettings(enabled_precisions={<dtype.f32: 7>}, debug=True, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, timing_cache_path='C:\\Users\\HolyWu\\AppData\\Local\\Temp\\torch_tensorrt_engine_cache\\timing_cache.bin', lazy_engine_init=False, cache_built_engines=False, reuse_cached_engines=False, use_explicit_typing=False, use_fp32_acc=False, refit_identical_engine_weights=False, strip_engine_weights=False, immutable_weights=True, enable_weight_streaming=False, enable_cross_compile_for_windows=False) Graph Structure: Inputs: List[Tensor: (1, 3, 224, 224)@float32, Tensor: (1, 224, 224, 2)@float32] ... TRT Engine #1 - Submodule name: _run_on_acc_0 Engine Inputs: List[Tensor: (1, 3, 224, 224)@float32, Tensor: (1, 224, 224, 2)@float32] Number of Operators in Engine: 1 Engine Outputs: List[Tensor: (1, 3, 224, 224)@float32] ... Outputs: List[Tensor: (1, 3, 224, 224)@float32] ------------------------- Aggregate Stats ------------------------- Average Number of Operators per TRT Engine: 1.0 Most Operators in a TRT Engine: 1 ********** Recommendations ********** - For minimal graph segmentation, select min_block_size=1 which would generate 1 TRT engine(s) - The current level of graph segmentation is equivalent to selecting min_block_size=1 which generates 1 TRT engine(s) DEBUG: [Torch-TensorRT] - Attempting to run engine (ID: _run_on_acc_0_engine); Hardware Compatible: 0 DEBUG: [Torch-TensorRT] - Input shape changed None -> (1,3,224,224)(1,224,224,2) DEBUG: [Torch-TensorRT] - Input Name: x Shape: [1, 3, 224, 224] DEBUG: [Torch-TensorRT] - Input Name: grid Shape: [1, 224, 224, 2] DEBUG: [Torch-TensorRT] - Output Name: output0 Shape: [1, 3, 224, 224] Traceback (most recent call last): File "C:\Users\HolyWu\Downloads\test.py", line 25, in <module> torch.testing.assert_close(trt_model(*inputs), model(*inputs), rtol=5e-3, atol=5e-3) File "C:\Python312\Lib\site-packages\torch\testing\_comparison.py", line 1530, in assert_close raise error_metas[0].to_error(msg) AssertionError: Tensor-likes are not close! Mismatched elements: 69917 / 150528 (46.4%) Greatest absolute difference: 3.082557439804077 at index (0, 2, 213, 105) (up to 0.005 allowed) Greatest relative difference: 280988.21875 at index (0, 1, 184, 207) (up to 0.005 allowed)
import os import torch import torch.nn.functional as F import torch_tensorrt os.environ["CI_BUILD"] = "1" class MyModule(torch.nn.Module): def __init__(self) -> None: super().__init__() def forward(self, x: torch.Tensor, grid: torch.Tensor) -> torch.Tensor: return F.grid_sample(x, grid, mode="bilinear", align_corners=True) with torch.inference_mode(): model = MyModule().eval().cuda() inputs = [torch.randn(1, 3, 224, 224, device="cuda"), torch.randn(1, 224, 224, 2, device="cuda")] trt_model = torch_tensorrt.compile(model, "dynamo", inputs, debug=True, min_block_size=1) torch.testing.assert_close(trt_model(*inputs), model(*inputs), rtol=5e-3, atol=5e-3) print("assert_close passed")
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_detach:Removed 0 detach nodes: graph(): %x : [num_users=1] = placeholder[target=x] %grid : [num_users=1] = placeholder[target=grid] %grid_sampler : [num_users=1] = call_function[target=torch.ops.aten.grid_sampler.default](args = (%x, %grid, 0, 0, True), kwargs = {}) return (grid_sampler,) DEBUG:torch_tensorrt.dynamo._compiler:Input graph: graph(): %x : [num_users=1] = placeholder[target=x] %grid : [num_users=1] = placeholder[target=grid] %cudnn_grid_sampler : [num_users=1] = call_function[target=torch.ops.aten.cudnn_grid_sampler.default](args = (%x, %grid), kwargs = {}) return (cudnn_grid_sampler,) DEBUG:torch_tensorrt.dynamo.lowering.passes.constant_folding:Graph after constant folding: graph(): %x : [num_users=1] = placeholder[target=x] %grid : [num_users=1] = placeholder[target=grid] %cudnn_grid_sampler : [num_users=1] = call_function[target=torch.ops.aten.cudnn_grid_sampler.default](args = (%x, %grid), kwargs = {}) return (cudnn_grid_sampler,) DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_assert_scalar:Removed 0 assert_scalar nodes: graph(): %x : [num_users=1] = placeholder[target=x] %grid : [num_users=1] = placeholder[target=grid] %cudnn_grid_sampler : [num_users=1] = call_function[target=torch.ops.aten.cudnn_grid_sampler.default](args = (%x, %grid), kwargs = {}) return (cudnn_grid_sampler,) DEBUG:torch_tensorrt.dynamo.lowering.passes.accumulate_fp32_matmul:Skipping FP32 accumulation for matmul layers as use_fp32_acc is not enabled in the compilation settings DEBUG:torch_tensorrt.dynamo._compiler:Lowered Input graph: graph(): %x : [num_users=1] = placeholder[target=x] %grid : [num_users=1] = placeholder[target=grid] %cudnn_grid_sampler : [num_users=1] = call_function[target=torch.ops.aten.cudnn_grid_sampler.default](args = (%x, %grid), kwargs = {}) return (cudnn_grid_sampler,) DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner: Supported Nodes: DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner: Unsupported or Excluded Nodes: - torch.ops.aten.cudnn_grid_sampler.default + Operator Count: 1 WARNING:torch_tensorrt.dynamo._compiler:0 supported operations detected in subgraph containing 1 computational nodes. Skipping this subgraph, since min_block_size was detected to be 1 assert_close passed