Refactor dynamo training backends by jansel · Pull Request #93409 · pytorch/pytorch

@@ -0,0 +1,145 @@ import logging import operator from collections import defaultdict from typing import Set
import torch
from torch.fx import GraphModule from torch.fx.passes.backends.cudagraphs import partition_cudagraphs from torch.multiprocessing.reductions import StorageWeakRef from torch.nn import Module from torch.utils._pytree import tree_map from .common import aot_autograd from .registry import register_backend
log = logging.getLogger(__name__)

def cloner(t): if isinstance(t, torch.Tensor): return t.clone() else: return t

class CudaGraphModule(Module): gm: GraphModule mutated_inputs: Set[int]
def __init__(self, gm, mutated_inputs): super().__init__() self.gm = gm self.mutated_inputs = mutated_inputs
warmed_up = False
# these are all None or all filled graph = None static_inputs = None static_outputs = None
# NB: we override __call__ as we don't need any nn.Module machinery # and to reduce overhead def __call__(self, *args): # TODO: once we've recorded here, we'd like to replace the __call__ # implementation with compiled bytecode that copies into static, replays # the cuda graph, then copies out. First condition is the hotpath, # needs optimizing if self.graph is not None: assert len(args) == len(self.static_inputs) for dst, src in zip(self.static_inputs, args): dst.copy_(src) self.graph.replay() for i in self.mutated_inputs: args[i].copy_(self.static_inputs[i]) return tree_map(cloner, self.static_outputs)
elif self.warmed_up: # record self.static_inputs = [x.clone() for x in args] self.graph = torch.cuda.CUDAGraph() with torch.cuda.graph(self.graph): self.static_outputs = self.gm(*self.static_inputs) # NB: recording doesn't actually run the operations, so # now we immediately replay the graph to serve up the result self.graph.replay() for i in self.mutated_inputs: args[i].copy_(self.static_inputs[i]) return tree_map(cloner, self.static_outputs)
else: # warmup stream = torch.cuda.Stream() stream.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(stream): r = self.gm(*args) torch.cuda.current_stream().wait_stream(stream) self.warmed_up = True return r

# Interpreter versions of these passes can be found at # https://gist.github.com/ezyang/df2d746cac3b2c7d55c181e37c57ef23

def find_input_mutations(g): def meta_fk(meta): return meta["val"] if "val" in meta else meta["fake_result"]
inputs = defaultdict(set) input_idx = 0 mutated_inputs = set() for n in g.nodes: if n.op == "placeholder": inputs[StorageWeakRef(meta_fk(n.meta)._typed_storage())].add(input_idx) input_idx += 1 elif n.op == "call_function": if n.target is operator.getitem: continue schema = n.target._schema for i, arg in enumerate(schema.arguments): if i < len(n.args): argument = n.args[i] else: if arg.name not in n.kwargs: continue argument = n.kwargs[arg.name] mut_arg = False if arg.alias_info: if arg.alias_info.is_write: mut_arg = True if mut_arg: # TODO: not correct for args that contain tensors in a struct # like list mutated_inputs |= inputs[ StorageWeakRef(meta_fk(argument.meta)._typed_storage()) ] # TODO: error on unrecognized nodes return mutated_inputs

# Mutates input graph def apply_cuda_graphs(gm): for n in gm.graph.nodes: if n.op == "call_module": assert not n.kwargs submod = gm.get_submodule(n.target) gm.delete_submodule(n.target) mutated_inputs = find_input_mutations(submod.graph) gm.add_submodule(n.target, CudaGraphModule(submod, mutated_inputs)) # NB: we didn't actually change the graph, no need for recompile

def cudagraphs(model, inputs): model = partition_cudagraphs(model, inputs) apply_cuda_graphs(model) return model

aot_cudagraphs = aot_autograd(fw_compiler=cudagraphs, bw_compiler=cudagraphs)
# aot_cudagraphs only applies CUDA graphs to the graph. It is also helpful # for debugging and can serve as a perf baseline. # TODO(jansel): rename to just "cudagraphs"? register_backend(name="aot_cudagraphs", compiler_fn=aot_cudagraphs)