Fix upsample converter not properly registered by HolyWu · Pull Request #2683 · pytorch/TensorRT

Even though the operator is properly registered along with #2681 being applied, the operator is still decomposed into lower-level operators rather than converted using this converter, just like #2665 (comment). Adding aten.upsample_bilinear2d.default and aten.upsample_bilinear2d.vec to torch_disabled_decompositions doesn't help. Compiling the model under with torch.inference_mode() also doesn't help. At the end I find out that I have to remove these two lines and this line in PyTorch to bypass the decomposition and then this converter finally works.

DEBUG:torch_tensorrt.dynamo._compiler:Input graph: graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %upsample_bilinear2d : [num_users=1] = call_function[target=torch.ops.aten.upsample_bilinear2d.default](args = (%arg0_1, [256, 256], True, 2.0, 2.0), kwargs = {})
    return (upsample_bilinear2d,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.constant_folding:Graph after constant folding:
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %upsample_bilinear2d : [num_users=1] = call_function[target=torch.ops.aten.upsample_bilinear2d.default](args = (%arg0_1, [256, 256], True, 2.0, 2.0), kwargs = {})
    return (upsample_bilinear2d,)
DEBUG:torch_tensorrt.dynamo._compiler:Lowered Input graph: graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %upsample_bilinear2d : [num_users=1] = call_function[target=torch.ops.aten.upsample_bilinear2d.default](args = (%arg0_1, [256, 256], True, 2.0, 2.0), kwargs = {})
    return (upsample_bilinear2d,)
INFO:torch_tensorrt.dynamo._compiler:Compilation Settings: CompilationSettings(precision=torch.float16, debug=True, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_long_and_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, sparse_weights=False, refit=False, engine_capability=<EngineCapability.DEFAULT: 0>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, output_format='exported_program')

DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:
- torch.ops.aten.upsample_bilinear2d.default + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
All Nodes Supported

DEBUG:torch_tensorrt.dynamo._compiler:Detected support for 1 operators out of 1 in subgraph.
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Number of TensorRT-Accelerated Engines Generated: 1
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Supported Nodes:
- torch.ops.aten.upsample_bilinear2d.default + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
All Nodes Supported

DEBUG:torch_tensorrt.dynamo._compiler:Submodule name: _run_on_acc_0
 Input shapes: [(1, 3, 128, 128)]
 graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %upsample_bilinear2d : [num_users=1] = call_function[target=torch.ops.aten.upsample_bilinear2d.default](args = (%arg0_1, [256, 256], True, 2.0, 2.0), kwargs = {})
    return upsample_bilinear2d
DEBUG:torch_tensorrt.dynamo.conversion.converter_utils:Node meta name arg0_1
DEBUG:torch_tensorrt.dynamo.conversion.converter_utils:Node meta name __/upsample_bilinear2d
DEBUG:torch_tensorrt.dynamo.conversion.converter_utils:Node meta name output
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT INetwork construction elapsed time: 0:00:00.000980
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Build TRT engine elapsed time: 0:00:01.093642
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT Engine uses: 0 bytes of Memory
DEBUG:torch_tensorrt.dynamo._DryRunTracker:
++++++++++++++++++++++++++++++++++++++++++++++++++ Dry-Run Results for Graph ++++++++++++++++++++++++++++++++++++++++++++++++++

The graph consists of 1 Total Operators, of which 1 operators are supported, 100.0% coverage

Compiled with: CompilationSettings(precision=torch.float16, debug=True, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_long_and_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, sparse_weights=False, refit=False, engine_capability=<EngineCapability.DEFAULT: 0>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, output_format='exported_program')

  Graph Structure:

   Inputs: List[Tensor: (1, 3, 128, 128)@float16]
    ...
    TRT Engine #1 - Submodule name: _run_on_acc_0
     Engine Inputs: List[Tensor: (1, 3, 128, 128)@float16]
     Number of Operators in Engine: 1
     Engine Outputs: Tensor: (1, 3, 256, 256)@float16
    ...
   Outputs: List[Tensor: (1, 3, 256, 256)@float16]

  ------------------------- Aggregate Stats -------------------------

   Average Number of Operators per TRT Engine: 1.0
   Most Operators in a TRT Engine: 1

  ********** Recommendations **********

   - For minimal graph segmentation, select min_block_size=1 which would generate 1 TRT engine(s)
   - The current level of graph segmentation is equivalent to selecting min_block_size=1 which generates 1 TRT engine(s)
WARNING: [Torch-TensorRT] - Using default stream in enqueue()/enqueueV2()/enqueueV3() may lead to performance issues due to additional cudaDeviceSynchronize() calls by TensorRT to ensure correct synchronizations. Please use non-default stream instead.