fix: Repair invalid schema arising from lowering pass by gs-olive · Pull Request #1786 · pytorch/TensorRT
- When `remove_unnecessary_casts` replaces both tensors in an `aten::div` call with the corresponding scalars, and the rounding mode (third argument) is specified, the schema becomes invalid - To avoid this, we intercept the call and split the result based on what the third argument in the input was, either delegating the result to `aten::Int(aten::div(...))`, for "trunc" or `aten::floordiv(...)` for "floor" - Update lowering pass to incorporate the above change, with appropriate documentation - Add testing to verify catching and correctly handling these edge cases