fix/feat: Add lowering pass to resolve most `aten::Int.Tensor` uses by gs-olive · Pull Request #1937 · pytorch/TensorRT
gs-olive
changed the title
fix/feat: Add lowering pass to resolve aten::Int.Tensor
fix/feat: Add lowering pass to resolve most aten::Int.Tensor invocations
gs-olive
changed the title
fix/feat: Add lowering pass to resolve most
fix/feat: Add lowering pass to resolve most aten::Int.Tensor invocationsaten::Int.Tensor uses
- Implement lowering pass which detects canonical `aten::Int.Tensor` cases and recursively replaces input Value pointers until all 0D tensors have been resolved to their scalar components - Lowering pass is specialized to replacing strictly integer-typed Value pointers and can only trace through aten::mul and aten::floor_divide operators, which are two of the most common cases of use - Lowering pass traverses the graph until one of three base cases are encountered (or an invalid Value type is detected). These cases are `prim::NumToTensor`, `prim::Constant` (0D tensor), or simple integers. It then replaces the child nodes with the integer equivalents of the produced Tensors - Added extensive testing of new capabilities for accuracy, robustness, and functionality
gs-olive
deleted the
replace_aten_int_schema
branch
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters