fix: replace add_identity by add_cast for type cast by junstar92 · Pull Request #3563 · pytorch/TensorRT

Description

This PR updates the type_cast helper function to ensure compatibility with TensorRT's strongly typed network mode.

type_cast used add_identity() followed by set_output_type() to perform the data type cast. However, in strongly typed mode, calling set_output_type() on the identity layer causes an error below:

ILayer::setOutputType: Error Code 3: API Usage Error (Parameter check failed, condition: !mNetwork->usingStronglyTyped(). INetworkLayer::setOutputType cannot be called for a strongly typed network.)
[graphShapeAnalyzer.cpp::checkCalculationStatusSanity::1962] Error Code 2: Internal Error (Assertion !isInFlight(p.second.symbolicRep) failed. )

type_cast is called by expand function in torch_tensorrt/dynamo/conversion/impl/slice/ops.py with dynamic dimension index.

input_t = prepend_ones(
ctx.net,
input_t,
name + "_expand_broadcast",
shape_rank - initial_tensor_rank,
)

The following code snippet reproduces the error:

import torch
import torch_tensorrt
from torch.export._trace import _export
from torch_tensorrt.dynamo._compiler import CompilationSettings
from torch_tensorrt.dynamo.conversion import TRTInterpreter
from torch_tensorrt.dynamo.lowering import get_decompositions


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.visual = torch.nn.Linear(10, 10)

    def forward(self, input: torch.Tensor):
        return input.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0)


model = Model().to("cuda")
x = torch.randn(1, 40).to("cuda")
ep = _export(model, (x,))
ep = ep.run_decompositions(get_decompositions(False))
gm = ep.module()


interpreter = TRTInterpreter(
    gm,
    [torch_tensorrt.Input(name="input", min_shape=(1, 40), opt_shape=(4, 40), max_shape=(8, 40), dtype=torch.float32)],
    compilation_settings=CompilationSettings(use_explicit_typing=True),
)
results = interpreter.run()

To address this, the function now uses add_cast() to explicitly insert a cast layer that converts the input tensor to the desired cast_type.

If there was a specific reason for using add_identity(), please let me know, as this change assumes that the identity layer was not essential beyond type casting.

Type of change

  • Bug fix (non-breaking change which fixes an issue)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR in so that relevant reviewers are notified