fix/feat: Add lowering pass to resolve most `aten::Int.Tensor` uses by gs-olive · Pull Request #1937 · pytorch/TensorRT

@gs-olive gs-olive changed the title fix/feat: Add lowering pass to resolve aten::Int.Tensor fix/feat: Add lowering pass to resolve most aten::Int.Tensor invocations

May 19, 2023

@gs-olive gs-olive changed the title fix/feat: Add lowering pass to resolve most aten::Int.Tensor invocations fix/feat: Add lowering pass to resolve most aten::Int.Tensor uses

May 19, 2023

@gs-olive

- Implement lowering pass which detects canonical `aten::Int.Tensor`
cases and recursively replaces input Value pointers until all 0D tensors
have been resolved to their scalar components
- Lowering pass is specialized to replacing strictly integer-typed Value pointers
and can only trace through aten::mul and aten::floor_divide operators,
which are two of the most common cases of use
- Lowering pass traverses the graph until one of three base cases are
encountered (or an invalid Value type is detected). These cases are
`prim::NumToTensor`, `prim::Constant` (0D tensor), or simple integers.
It then replaces the child nodes with the integer equivalents of the
produced Tensors
- Added extensive testing of new capabilities for accuracy, robustness,
and functionality

narendasan

gs-olive

@gs-olive

- Edit in favor of `c10::optional` type usage

narendasan

@gs-olive gs-olive deleted the replace_aten_int_schema branch

May 30, 2023 21:54

narendasan pushed a commit that referenced this pull request

Jun 2, 2023

narendasan pushed a commit that referenced this pull request

Jun 2, 2023

narendasan pushed a commit that referenced this pull request

Jun 3, 2023