diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp index 64b111750f..6ea4d01d03 100644 --- a/core/runtime/execute_engine.cpp +++ b/core/runtime/execute_engine.cpp @@ -107,12 +107,6 @@ void setup_input_tensors( TORCHTRT_CHECK( inputs[i].is_cuda(), "Expected input tensors to have device cuda, found device " << inputs[i].device()); - auto expected_type = - util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getTensorDataType(name.c_str())); - TORCHTRT_CHECK( - inputs[i].dtype() == expected_type, - "Expected input tensors to have type " << expected_type << ", found type " << inputs[i].dtype()); - auto dims = core::util::toDims(inputs[i].sizes()); auto shape = core::util::toVec(dims); LOG_DEBUG("Input Name: " << name << " Shape: " << dims); diff --git a/examples/dynamo/autocast_example.py b/examples/dynamo/autocast_example.py new file mode 100644 index 0000000000..f1487cfb72 --- /dev/null +++ b/examples/dynamo/autocast_example.py @@ -0,0 +1,75 @@ +import torch +import torch.nn as nn +import torch_tensorrt + + +class AutocastExample(nn.Module): + def __init__(self): + super(AutocastExample, self).__init__() + self.conv1 = nn.Conv2d( + in_channels=3, out_channels=8, kernel_size=3, stride=1, padding=1 + ) + self.relu1 = nn.ReLU() + self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) + self.conv2 = nn.Conv2d( + in_channels=8, out_channels=16, kernel_size=3, stride=1, padding=1 + ) + self.relu2 = nn.ReLU() + self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) + self.flatten = nn.Flatten() + self.fc1 = nn.Linear(16 * 8 * 8, 10) + + def forward(self, x, y): + x = self.conv1(x) # fp32 because of "^conv1$" in `autocast_excluded_nodes` + x = self.relu1(x) # fp32 because of "relu" in `autocast_excluded_nodes` + out = self.pool1(x) # fp16 + x = self.conv2(out) # fp16 + x = self.relu2(x) # fp32 because of "relu" in `autocast_excluded_nodes` + x = self.pool2(x) # fp16 + x = self.flatten( + x + ) # fp32 because of `torch.ops.aten.flatten.using_ints` in `autocast_excluded_ops` + # Respect the precisions in the pytorch autocast context + with torch.autocast(x.device.type, enabled=True, dtype=torch.float32): + x = self.fc1(x) + with torch.autocast(x.device.type, enabled=False): + x = torch.sub(x.half(), y) + out2 = torch.add(x, x) + with torch.autocast(x.device.type, enabled=True, dtype=torch.float16): + out2 = torch.log(out2) + return x, out, out2 + + +if __name__ == "__main__": + model = AutocastExample().cuda().eval() + inputs = ( + torch.randn((1, 3, 32, 32), dtype=torch.float32, device="cuda"), + torch.randn((1,), dtype=torch.float16, device="cuda"), + ) + + ep = torch.export.export(model, inputs) + + with torch_tensorrt.dynamo.Debugger( + "graphs", + logging_dir=".", + engine_builder_monitor=False, + ): + trt_mod = torch_tensorrt.compile( + ep.module(), + arg_inputs=inputs, + min_block_size=1, + use_python_runtime=True, + ##### weak typing ##### + # use_explicit_typing=False, + # enabled_precisions={torch.float16}, + ##### strong typing + autocast ##### + use_explicit_typing=True, + enable_autocast=True, + autocast_low_precision_type=torch.float16, + autocast_excluded_nodes={"^conv1$", "relu"}, + autocast_excluded_ops={torch.ops.aten.flatten.using_ints}, + autocast_data_max=512, + autocast_max_depth_of_reduction=None, + ) + + trt_out = trt_mod(*inputs) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index c8ad938032..a78ae0a813 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py 
@@ -141,7 +141,7 @@ def cross_compile_for_windows( disable_tf32 (bool): Force FP32 layers to use traditional as FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulates the sum using 23-bit mantissas assume_dynamic_shape_support (bool): Setting this to true enables the converters work for both dynamic and static shapes. Default: False sparse_weights (bool): Enable sparsity for convolution and fully connected layers. - enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels + enabled_precisions (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels capability (torch_tensorrt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels workspace_size (int): Maximum size of workspace given to TensorRT @@ -434,6 +434,16 @@ def compile( l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING, offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU, use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE, + enable_autocast: bool = _defaults.ENABLE_AUTOCAST, + autocast_low_precision_type: Optional[ + Union[torch.dtype, dtype] + ] = _defaults.AUTOCAST_LOW_PRECISION_TYPE, + autocast_excluded_nodes: Collection[str] = _defaults.AUTOCAST_EXCLUDED_NODES, + autocast_excluded_ops: Collection[Target] = _defaults.AUTOCAST_EXCLUDED_OPS, + autocast_data_max: float = _defaults.AUTOCAST_DATA_MAX, + autocast_max_depth_of_reduction: Optional[ + int + ] = _defaults.AUTOCAST_MAX_DEPTH_OF_REDUCTION, **kwargs: Any, ) -> torch.fx.GraphModule: """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT @@ -511,6 +521,12 @@ def compile( l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit). offload_module_to_cpu (bool): Offload the module to CPU. This is useful when we need to minimize GPU memory usage. use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model + enable_autocast (bool): Whether to enable autocast. If enabled, use_explicit_typing will be set to True. + autocast_low_precision_type (Optional[Union[torch.dtype, dtype]]): The precision to reduce to. We currently support torch.float16 and torch.bfloat16. Default is None, which means no low precision is used. + autocast_excluded_nodes (Collection[str]): The set of regex patterns to match node names that should remain in FP32. Default is []. + autocast_excluded_ops (Collection[Target]): The set of targets (ATen ops) that should remain in FP32. Default is []. + autocast_data_max (float): Maximum absolute value for node outputs, nodes with outputs greater than this value will remain in FP32. Default is 512. + autocast_max_depth_of_reduction (Optional[int]): Maximum depth of reduction allowed in low precision. Nodes with higher reduction depths will remain in FP32. If not provided, infinity will be used. Default is None. **kwargs: Any, Returns: torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT @@ -584,6 +600,10 @@ def compile( "\nThis feature is unimplemented in Torch-TRT Dynamo currently." 
) + if enable_autocast: + use_explicit_typing = True + logger.debug("Autocast is enabled, setting use_explicit_typing to True.") + + if use_explicit_typing: if len(enabled_precisions) != 1 or not any( x in enabled_precisions @@ -593,6 +613,19 @@ def compile( f"use_explicit_typing was set to True, however found that enabled_precisions was also specified (saw: {enabled_precisions}, expected: dtype.f32, dtype.f4). enabled_precisions should not be used when use_explicit_typing=True" ) + if autocast_low_precision_type is not None: + if not isinstance(autocast_low_precision_type, (torch.dtype, dtype)): + raise ValueError( + f"autocast_low_precision_type must be a torch.dtype or torch_tensorrt._enums.dtype, got {type(autocast_low_precision_type)}" + ) + if autocast_low_precision_type not in { + torch.float16, + torch.bfloat16, + } and autocast_low_precision_type not in {dtype.f16, dtype.bf16}: + raise ValueError( + f"autocast_low_precision_type must be one of torch.float16, torch.bfloat16, dtype.f16, dtype.bf16, got {autocast_low_precision_type}" + ) + if use_fp32_acc: logger.debug( "FP32 accumulation for matmul layers is enabled. This option should only be enabled if the model already has FP16 weights and has no effect if it has FP32 weights. \ @@ -622,6 +655,38 @@ def compile( if not isinstance(arg_inputs, collections.abc.Sequence): arg_inputs = [arg_inputs] # type: ignore + # save intermediate outputs of each node for Autocast + autocast_intermediate_node_outputs = {} + if enable_autocast: + + class DumpInterpreter(torch.fx.Interpreter): # type: ignore[misc] + """Dump intermediate outputs of each node""" + + def run_node(self, n: torch.fx.Node) -> Any: + if ( + n.op == "call_function" + and n.target != torch.ops.higher_order.wrap_with_autocast + ): + out = super().run_node(n) + if not isinstance(out, torch.Tensor): + raise ValueError( + f"Autocast expected a torch.Tensor output for node {n.name} but got {type(out)}; please file a bug with Torch-TensorRT."
+ ) + autocast_intermediate_node_outputs[n.name] = out + return out + return super().run_node(n) + + def _materialize(x: Input | torch.Tensor) -> torch.Tensor: + """Materialize an Input object to a tensor""" + if isinstance(x, Input): + return x.torch_tensor + return x + + with torch.no_grad(): + mat_args = tuple(_materialize(a) for a in arg_inputs) + mat_kwargs = {k: _materialize(v) for k, v in kwarg_inputs.items()} + DumpInterpreter(exported_program.module()).run(*mat_args, **mat_kwargs) + # Prepare torch_trt inputs trt_arg_inputs: Sequence[Input] = prepare_inputs(arg_inputs) trt_kwarg_inputs: Optional[dict[Any, Any]] = prepare_inputs(kwarg_inputs) @@ -680,6 +745,13 @@ def compile( "l2_limit_for_tiling": l2_limit_for_tiling, "offload_module_to_cpu": offload_module_to_cpu, "use_distributed_mode_trace": use_distributed_mode_trace, + "enable_autocast": enable_autocast, + "autocast_low_precision_type": autocast_low_precision_type, + "autocast_excluded_nodes": autocast_excluded_nodes, + "autocast_excluded_ops": autocast_excluded_ops, + "autocast_data_max": autocast_data_max, + "autocast_max_depth_of_reduction": autocast_max_depth_of_reduction, + "autocast_intermediate_node_outputs": autocast_intermediate_node_outputs, } settings = CompilationSettings(**compilation_options) diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index de970ecd81..a92fcf9d4e 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -57,6 +57,12 @@ L2_LIMIT_FOR_TILING = -1 USE_DISTRIBUTED_MODE_TRACE = False OFFLOAD_MODULE_TO_CPU = False +ENABLE_AUTOCAST = False +AUTOCAST_LOW_PRECISION_TYPE = None +AUTOCAST_EXCLUDED_NODES = set[str]() +AUTOCAST_EXCLUDED_OPS = set[torch.fx.node.Target]() +AUTOCAST_DATA_MAX = 512 +AUTOCAST_MAX_DEPTH_OF_REDUCTION = None if platform.system() == "Linux": import pwd diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index d8f6809eae..814e75f917 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -1,17 +1,24 @@ from dataclasses import dataclass, field from typing import Any, Collection, Optional, Set, Tuple, Union +import torch from torch.fx.node import Target from torch_tensorrt._Device import Device from torch_tensorrt._enums import EngineCapability, dtype from torch_tensorrt.dynamo._defaults import ( ASSUME_DYNAMIC_SHAPE_SUPPORT, + AUTOCAST_DATA_MAX, + AUTOCAST_EXCLUDED_NODES, + AUTOCAST_EXCLUDED_OPS, + AUTOCAST_LOW_PRECISION_TYPE, + AUTOCAST_MAX_DEPTH_OF_REDUCTION, CACHE_BUILT_ENGINES, DISABLE_TF32, DLA_GLOBAL_DRAM_SIZE, DLA_LOCAL_DRAM_SIZE, DLA_SRAM_SIZE, DRYRUN, + ENABLE_AUTOCAST, ENABLE_CROSS_COMPILE_FOR_WINDOWS, ENABLE_EXPERIMENTAL_DECOMPOSITIONS, ENABLE_WEIGHT_STREAMING, @@ -97,6 +104,13 @@ class CompilationSettings: tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"]. l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit). use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model + enable_autocast (bool): Whether to enable autocast. If enabled, use_explicit_typing will be set to True. 
+ autocast_low_precision_type (Optional[Union[torch.dtype, dtype]]): The precision to reduce to. We currently support torch.float16 and torch.bfloat16. Default is None, which means no low precision is used. + autocast_excluded_nodes (Collection[str]): The set of regex patterns to match node names that should remain in FP32. Default is []. + autocast_excluded_ops (Collection[Target]): The set of targets (ATen ops) that should remain in FP32. Default is []. + autocast_data_max (float): Maximum absolute value for node outputs, nodes with outputs greater than this value will remain in FP32. Default is 512. + autocast_max_depth_of_reduction (Optional[int]): Maximum depth of reduction allowed in low precision. Nodes with higher reduction depths will remain in FP32. If not provided, infinity will be used. Default is None. + autocast_intermediate_node_outputs (dict[str, torch.Tensor]): The intermediate node outputs of the graph. Default is {}. """ enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS) @@ -140,6 +154,19 @@ class CompilationSettings: l2_limit_for_tiling: int = L2_LIMIT_FOR_TILING use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE offload_module_to_cpu: bool = OFFLOAD_MODULE_TO_CPU + enable_autocast: bool = ENABLE_AUTOCAST + autocast_low_precision_type: Optional[dtype] = AUTOCAST_LOW_PRECISION_TYPE + autocast_excluded_nodes: Collection[str] = field( + default_factory=lambda: AUTOCAST_EXCLUDED_NODES + ) + autocast_excluded_ops: Collection[Target] = field( + default_factory=lambda: AUTOCAST_EXCLUDED_OPS + ) + autocast_data_max: float = AUTOCAST_DATA_MAX + autocast_max_depth_of_reduction: Optional[int] = AUTOCAST_MAX_DEPTH_OF_REDUCTION + autocast_intermediate_node_outputs: dict[str, torch.Tensor] = field( + default_factory=lambda: {} + ) def __getstate__(self) -> dict[str, Any]: from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( @@ -157,6 +184,7 @@ def __setstate__(self, state: dict[str, Any]) -> None: self.__dict__.update(state) +# If any of the following settings is changed, the engine should be rebuilt. _SETTINGS_TO_BE_ENGINE_INVARIANT = ( "enabled_precisions", "max_aux_streams", diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py index e5183668ae..1499e670bd 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py @@ -15,6 +15,13 @@ from .remove_num_users_is_0_nodes import remove_num_users_is_0_nodes from .repair_input_as_output import repair_input_as_output from .replace_max_pool_with_indices import replace_max_pool_with_indices +from .rule_based_autocast import rule_based_autocast + +pre_lowering_pass_list = [ + remove_detach, + rule_based_autocast, + remove_assert_nodes, # rule_based_autocast might insert assert nodes +] post_lowering_pass_list = [ remove_input_alias_fixing_clones, @@ -27,10 +34,6 @@ complex_graph_detection, ] -pre_lowering_pass_list = [ - remove_detach, -] - if not is_tegra_platform(): from .fuse_distributed_ops import fuse_distributed_ops diff --git a/py/torch_tensorrt/dynamo/lowering/passes/nodeclassifier.py b/py/torch_tensorrt/dynamo/lowering/passes/nodeclassifier.py new file mode 100644 index 0000000000..b7b7c770c3 --- /dev/null +++ b/py/torch_tensorrt/dynamo/lowering/passes/nodeclassifier.py @@ -0,0 +1,304 @@ +# Borrowed from ModelOpt AutoCast's nodeclassifier.py, modified to fit Torch-TensorRT's needs.
+import abc +import logging +import operator +import re +from typing import Collection, Optional + +import torch + +logger = logging.getLogger(__name__) + + +class NodeRuleBase: + """Base class for node classification rules. + + This class defines the interface for rules that determine whether a node + should be kept in high precision or converted to low precision. + """ + + @abc.abstractmethod + def _check_inner(self, node): + """Implement this method to check if node conversion should be skipped based on rule criteria.""" + + def _log_skipped(self, node, **kwargs): + """Log information about skipped nodes.""" + logger.info(f"Skipping node {node.name}: {self.__class__.__name__}") + + def check(self, node): + """Check if a node should be skipped based on the rule. + + Args: + node: The FX node to check. + + Returns: + bool: True if the node should be kept in high precision, False otherwise. + """ + result = self._check_inner(node) + if result: + self._log_skipped(node) + return True + return False + + +class DisabledNodeNameRegexRule(NodeRuleBase): + """Rule for keeping nodes with matching names in high precision.""" + + def __init__(self, disabled_node_name_regex): + """Initialize the rule. + + Args: + disabled_node_name_regex: List of regex patterns for node names to keep in high precision. + """ + self.disabled_node_name_regex = disabled_node_name_regex + + def _check_inner(self, node): + stack = node.meta.get("nn_module_stack") + node_name = next(reversed(stack), "").split("__")[ + -1 + ] # get the user specified name of the node + return any( + re.match(regex, node_name) for regex in self.disabled_node_name_regex + ) + + +class DisabledOpTypes(NodeRuleBase): + """Rule for keeping nodes with specific operation types in high precision.""" + + def __init__(self, excluded_ops): + """Initialize the rule. + + Args: + excluded_ops: List of operation types to keep in high precision. + """ + self.excluded_ops = excluded_ops + + def _check_inner(self, node): + return node.target in self.excluded_ops + + +class IORangeRule(NodeRuleBase): + """Rule for keeping nodes with out-of-range inputs/outputs in high precision.""" + + def __init__(self, data_max, reference_data): + """Initialize the rule. + + Args: + data_max: Maximum absolute value allowed for node I/O. + reference_data: Reference data for checking I/O ranges. + """ + self.data_max = data_max + self.reference_data = reference_data + self.output_data = None + + def _check_inner(self, node): + def is_io_out_of_range(node): + tensor_name = node.name + if tensor_name not in self.reference_data: + logger.debug( + f"Node {node.name}: Tensor {tensor_name} not found in reference data. Skipping I/O range check." + ) + return False + ref_data = self.reference_data[tensor_name] + if ref_data.numel() == 0: + logger.debug( + f"Node {node.name}: Tensor {tensor_name} has 0 elements. Skipping I/O range check."
+ ) + return False + logger.debug( + f"Node {node.name}: reference data: min={ref_data.min()}, max={ref_data.max()}" + ) + if torch.any(torch.abs(ref_data) > self.data_max): + self.output_data = ref_data + return True + + if self.reference_data: + for in_node in node.all_input_nodes: + if is_io_out_of_range(in_node): + return True + for out_node in list(node.users): + if is_io_out_of_range(out_node): + return True + return False + + def _log_skipped(self, node, **kwargs): + """Log information about skipped nodes with I/O range violations.""" + if self.output_data is not None: + logger.info( + f"Skipping node {node.name}: reference IO out of range: min={torch.min(self.output_data)}, " + f"max={torch.max(self.output_data)}, range=[{-self.data_max}, {self.data_max}]" + ) + else: + super()._log_skipped(node, **kwargs) + + +class DepthOfReductionRule(NodeRuleBase): + """Rule for keeping nodes with high depth of reduction in high precision.""" + + def __init__(self, max_depth_of_reduction, reference_data): + """Initialize the rule. + + Args: + max_depth_of_reduction: Maximum depth of reduction allowed in low precision. + reference_data: Reference data for checking I/O ranges. + """ + self.max_depth_of_reduction = max_depth_of_reduction + self.reference_data = reference_data + self.reduction_depth = 0 + + def _get_tensor_shape(self, tensor_name): + """Get tensor shape from reference data.""" + if tensor_name in self.reference_data: + return self.reference_data[tensor_name].shape + return None + + def _log_skipped(self, node, **kwargs): + """Log information about skipped nodes with depth of reduction violations.""" + if self.reduction_depth > 0: + logger.info( + f"Skipping node {node.name}: depth of reduction {self.reduction_depth} exceeds " + f"{self.max_depth_of_reduction}." + ) + else: + super()._log_skipped(node, **kwargs) + + def _check_inner(self, node): + # All reduction ops rely on shape of input[0] + input_0_dims = ( + self._get_tensor_shape(node.all_input_nodes[0].name) + if len(node.all_input_nodes) > 0 + else None + ) + if input_0_dims is None: + return False + self.reduction_depth = 0 + if node.target in [ + torch.ops.aten.scaled_dot_product_attention.default, + ]: + # Attention: input (batch_size, sequence_length, hidden_size) + # or (batch_size, kv_num_heads, total_sequence_length, head_size) + assert len(input_0_dims) == 3 or len(input_0_dims) == 4 + hidden_size = ( + input_0_dims[2] + if len(input_0_dims) == 3 + else input_0_dims[1] * input_0_dims[3] + ) + self.reduction_depth = hidden_size + elif node.target in [ + torch.ops.aten.convolution.default, + torch.ops.aten.conv1d.default, + torch.ops.aten.conv2d.default, + torch.ops.aten.conv3d.default, + ]: + # Conv: input (N x C x D1 x D2 ... x Dn) + # weight (out_channels, in_channels, kD1, kD2, ... 
kDn) + # Reduction depth = in_channels * kernel_volume + weight_shape = ( + self._get_tensor_shape(node.all_input_nodes[1].name) + if len(node.all_input_nodes) > 1 + else None + ) + if weight_shape is None: + return False + in_channels = weight_shape[1] + kernel_volume = int(torch.prod(torch.tensor(weight_shape[2:]))) + self.reduction_depth = in_channels * kernel_volume + elif node.target in [ + torch.ops.aten.matmul, + torch.ops.aten.matmul.default, + torch.ops.aten.dot.default, + torch.ops.aten.mm.default, + torch.ops.aten.mv.default, + torch.ops.aten.bmm.default, + ]: + # GEMM: A (M, K) @ B (K, N) = C (M, N) + self.reduction_depth = input_0_dims[-1] + # TODO: Add more reduction ops here + return self.reduction_depth > self.max_depth_of_reduction + + +class NodeClassifier: + """Main class for classifying nodes into high and low precision groups.""" + + def __init__( + self, + nodes, + excluded_nodes: Collection[str] | None = None, + excluded_ops: Collection[torch.fx.node.Target] | None = None, + custom_rule: NodeRuleBase | None = None, + data_max: float | None = 1000.0, + max_depth_of_reduction: int | None = None, + ): + """Initialize the node classifier. + + Args: + nodes: The nodes to classify. + excluded_nodes: Collection of regex patterns for node names to keep in high precision. + excluded_ops: Collection of targets to keep in high precision. + custom_rule: Optional custom classification rule. + data_max: Maximum absolute value allowed for node I/O. + max_depth_of_reduction: Maximum depth of reduction allowed in low precision. + """ + self.nodes = nodes + self.excluded_nodes = excluded_nodes + self.excluded_ops = excluded_ops + self.custom_rule = custom_rule + self.data_max = data_max + self.max_depth_of_reduction = max_depth_of_reduction + + def _gen_block_node_rules(self, reference_data): + """Generate list of rules for blocking nodes from precision conversion. + + Args: + reference_data: Reference data for checking I/O ranges. + + Returns: + list[NodeRuleBase]: List of rules to apply. + """ + block_node_rules: list[NodeRuleBase] = [] + if self.excluded_nodes: + block_node_rules.append(DisabledNodeNameRegexRule(self.excluded_nodes)) + if self.excluded_ops: + block_node_rules.append(DisabledOpTypes(self.excluded_ops)) + if reference_data: + block_node_rules.append(IORangeRule(self.data_max, reference_data)) + if self.max_depth_of_reduction is not None: + block_node_rules.append( + DepthOfReductionRule( + self.max_depth_of_reduction, + reference_data, + ) + ) + if self.custom_rule: + block_node_rules.append(self.custom_rule) + return block_node_rules + + def run( + self, ref_outputs_dict: Optional[dict[str, torch.Tensor]] = None + ) -> tuple[list[str], list[str]]: + """Run node classification. + + Args: + ref_outputs_dict: Optional tensors' reference data. + + Returns: + tuple: Lists of node names (low_precision_nodes, high_precision_nodes).
+ """ + block_node_rules = self._gen_block_node_rules(ref_outputs_dict) + low_precision_nodes = [] + high_precision_nodes = [] + for node in self.nodes: + if node.op == "call_function": + if ( + node.target == torch.ops.higher_order.wrap_with_autocast + or node.target == operator.getitem + ): + continue + # If any condition is met - node will be executed in high precision + if any(rule.check(node) for rule in block_node_rules): + high_precision_nodes.append(node.name) + else: + low_precision_nodes.append(node.name) + logger.debug(f"Low Precision Nodes: {low_precision_nodes}") + logger.debug(f"High Precision Nodes: {high_precision_nodes}") + return low_precision_nodes, high_precision_nodes diff --git a/py/torch_tensorrt/dynamo/lowering/passes/rule_based_autocast.py b/py/torch_tensorrt/dynamo/lowering/passes/rule_based_autocast.py new file mode 100644 index 0000000000..097e17a944 --- /dev/null +++ b/py/torch_tensorrt/dynamo/lowering/passes/rule_based_autocast.py @@ -0,0 +1,113 @@ +import logging +import operator +from typing import Any + +import torch +from torch_tensorrt._enums import dtype +from torch_tensorrt.dynamo._settings import CompilationSettings + +from .nodeclassifier import NodeClassifier +from .pass_utils import clean_up_graph_after_modifications + +logger = logging.getLogger(__name__) + + +def is_tensor_node(n: torch.fx.Node) -> bool: + val = n.meta.get("val", None) + if hasattr(val, "dtype"): + return True + return False + + +def rule_based_autocast( + gm: torch.fx.GraphModule, settings: CompilationSettings +) -> torch.fx.GraphModule: + """Rule-based autocast""" + if not settings.enable_autocast: + logger.debug("Autocast is not enabled, skipping rule-based autocast.") + return gm + + # get config from settings + autocast_low_precision_type = settings.autocast_low_precision_type + if autocast_low_precision_type is None: + return gm + if isinstance(autocast_low_precision_type, dtype): + autocast_low_precision_type = autocast_low_precision_type.to(torch.dtype) + high_precision_type = torch.float32 + autocast_excluded_nodes = settings.autocast_excluded_nodes + autocast_excluded_ops = settings.autocast_excluded_ops + autocast_data_max = settings.autocast_data_max + autocast_max_depth_of_reduction = settings.autocast_max_depth_of_reduction + reference_data: dict[str, torch.Tensor] = ( + settings.autocast_intermediate_node_outputs + ) + + node_classifier = NodeClassifier( + gm.graph.nodes, + excluded_nodes=autocast_excluded_nodes, + excluded_ops=autocast_excluded_ops, + data_max=autocast_data_max, + max_depth_of_reduction=autocast_max_depth_of_reduction, + ) + low_precision_nodes, high_precision_nodes = node_classifier.run(reference_data) + + for node in list(gm.graph.nodes): + if node.op == "call_function": + if ( + node.target == torch.ops.higher_order.wrap_with_autocast + or node.target == operator.getitem + ): + continue + + def _cast_all_tensor_args_to_dtype(arg: Any, dtype: torch.dtype) -> Any: + """Cast all tensor args to the given dtype + + Args: + arg: The argument to cast + dtype: The dtype to cast to + + Returns: + The casted argument + """ + if isinstance(arg, torch.fx.Node) and is_tensor_node(arg): + val = arg.meta.get("val", None) + with gm.graph.inserting_before(node): + cast = gm.graph.call_function( + torch.ops.aten.to.dtype, args=(arg, dtype) + ) + + if isinstance(val, torch.Tensor): + arg.meta["val"] = val.to(dtype) + cast.meta.update(arg.meta) + return cast + elif isinstance(arg, (tuple, list)): + return type(arg)( + _cast_all_tensor_args_to_dtype(a, dtype) for a 
in arg + ) + elif isinstance(arg, dict): + return { + k: _cast_all_tensor_args_to_dtype(v, dtype) + for k, v in arg.items() + } + else: + return arg + + if node.name in low_precision_nodes: + node.args = _cast_all_tensor_args_to_dtype( + node.args, autocast_low_precision_type + ) + node.kwargs = _cast_all_tensor_args_to_dtype( + node.kwargs, autocast_low_precision_type + ) + elif node.name in high_precision_nodes: + node.args = _cast_all_tensor_args_to_dtype( + node.args, high_precision_type + ) + node.kwargs = _cast_all_tensor_args_to_dtype( + node.kwargs, high_precision_type + ) + + gm = clean_up_graph_after_modifications(gm) + logger.debug("Graph after Autocast based on the rules:\n%s", gm.graph) + + return gm diff --git a/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py index 9e54fbac3d..24166eb895 100644 --- a/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py @@ -154,10 +154,6 @@ def forward( + contiguous_inputs[i + 1 :] ) - assert ( - contiguous_inputs[i].dtype == inputs[i].dtype - ), f"Dtype mismatch for {i}th input. Expect {inputs[i].dtype}, got {contiguous_inputs[i].dtype}." - if need_cudagraphs_record: # If cudagraphs is enabled, this memory is reserved for future cudagraph runs # Clone is required to avoid re-using user-provided GPU memory diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index d18a5674e0..0eb5ebbbca 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -275,10 +275,6 @@ def setup_engine(self) -> None: len(self.input_names) + len(self.output_names) ) - self.input_dtypes = [ - dtype._from(self.engine.get_tensor_dtype(input_name)) - for input_name in self.input_names - ] self.input_shapes = [ self.engine.get_tensor_shape(input_name) for input_name in self.input_names ] @@ -371,10 +367,6 @@ def setup_input_tensors( + contiguous_inputs[i + 1 :] ) - assert ( - contiguous_inputs[i].dtype == self.input_dtypes[i] - ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}." - if need_cudagraphs_record: # If cudagraphs is enabled, this memory is reserved for future cudagraph runs # Clone is required to avoid re-using user-provided GPU memory