Shortcuts

Source code for torch.jit._trace

"""Tracing

This module contains functionality to support the JIT's tracing frontend, notably:
    * torch.jit.trace
    * torch.jit.trace_module

This is not intended to be imported directly; please use the exposed
functionalities in `torch.jit`.
"""
import torch

import copy
import os
import contextlib
import functools
import warnings
import inspect
import re
from typing import Any, Dict, List, Optional, Set

from torch.jit._state import _python_cu, _enabled
from torch.jit._script import ScriptModule, _CachedForward, script
from torch._jit_internal import _qualified_name, is_scripting, get_callable_argument_names
from torch.autograd import function
from torch.nn import Module

from torch.testing._comparison import default_tolerances

# Short module-level aliases for the C++ flatten/unflatten helpers.
# As used in this file, ``_flatten(x)`` returns a ``(flat_values, descriptor)``
# pair and ``_unflatten(flat_values, descriptor)`` is called with such a pair
# to get back the nested arguments.
_flatten = torch._C._jit_flatten
_unflatten = torch._C._jit_unflatten


def _create_interpreter_name_lookup_fn(frames_up=1):
    """Build a callback that maps a tensor to the local variable name holding it.

    Args:
        frames_up: Number of additional frames to walk up beyond the direct
            caller of the returned callback before inspecting local
            variables. ``0`` inspects the frame that called the callback;
            the default ``1`` inspects that frame's caller.

    Returns:
        A function ``f(var) -> str``. It returns the name of the first local
        variable in the selected frame that is the very same
        ``torch.Tensor`` object as ``var``. It returns ``""`` when nothing
        matches, or when the matching local is named ``self``.

        The returned function raises ``RuntimeError`` when the current frame
        cannot be inspected or the call stack is shallower than requested.
    """

    def _get_interpreter_name_for_var(var):
        frame = inspect.currentframe()
        if not frame:
            raise RuntimeError("failed to inspect frame")

        # The first hop leaves this helper's own frame; the remaining
        # ``frames_up`` hops select the frame whose locals are searched.
        for _ in range(frames_up + 1):
            frame = frame.f_back
            if not frame:
                raise RuntimeError("failed to get frame")

        # Match by identity, not equality: we want the variable that holds
        # this exact tensor object.
        for k, v in frame.f_locals.items():
            if isinstance(v, torch.Tensor) and var is v:
                return k if k != "self" else ""
        return ""

    return _get_interpreter_name_for_var


def _unique_state_dict(module, keep_vars=False):
    """Return ``module``'s state dict with aliased entries removed.

    The state dict is always fetched with ``keep_vars=True`` so that the
    values are the original Parameter/Buffer objects: ``detach()`` creates a
    new tensor object on every call, so ``id`` would be useless for
    deduplication on detached values. Only the first key that refers to a
    given object is kept.

    Args:
        module: Object exposing ``state_dict(keep_vars=...)``.
        keep_vars: When true the kept values are the original objects;
            otherwise each kept value is ``detach()``-ed.

    Returns:
        A new mapping of the same type as ``module.state_dict()``.
    """
    full_state = module.state_dict(keep_vars=True)
    unique_state = type(full_state)()
    visited: Set[int] = set()
    for name, tensor in full_state.items():
        identity = id(tensor)
        if identity not in visited:
            visited.add(identity)
            unique_state[name] = tensor if keep_vars else tensor.detach()
    return unique_state


class ONNXTracedModule(torch.nn.Module):
    """Module wrapper whose ``forward`` traces ``inner`` and returns the graph.

    ``forward`` does not return the plain result of ``inner``. It returns a
    tuple starting with the graph produced by
    ``torch._C._create_graph_by_tracing`` followed by the output of ``inner``
    (see ``forward`` for the optional extra element).
    """

    def __init__(
        self,
        inner,
        strict=True,
        force_outplace=False,
        return_inputs=False,
        return_inputs_states=False,
    ):
        """Store the callable to trace and the tracing options.

        Args:
            inner: A Module or an arbitrary callable to be traced.
            strict: Forwarded to ``_create_graph_by_tracing``.
            force_outplace: Forwarded to ``_create_graph_by_tracing``.
            return_inputs: When true, ``forward`` also returns clones of the
                arguments the tracer passed to the wrapped call.
            return_inputs_states: When true (and ``return_inputs`` is false),
                ``forward`` also returns the recorded input states.
        """
        super(ONNXTracedModule, self).__init__()
        # inner may be a Module, or it may be an arbitrary callable
        # If it's a Module, we get its parameters automatically, which lets
        # us avoid a special casing functions versus modules.
        self.inner = inner
        self.strict = strict
        self._force_outplace = force_outplace
        self._return_inputs = return_inputs
        self._return_inputs_states = return_inputs_states

    def forward(self, *args: torch.Tensor):
        """Trace ``self.inner`` on ``args``.

        Returns:
            ``(graph, out, ret_inputs)`` when ``return_inputs`` is set,
            otherwise ``(graph, out, inputs_states)`` when
            ``return_inputs_states`` is set, otherwise ``(graph, out)``.
            ``return_inputs`` takes precedence when both flags are set.
        """
        in_vars, in_desc = _flatten(args)
        # NOTE: use full state, because we need it for BatchNorm export
        # This differs from the compiler path, which doesn't support it at the moment.
        module_state = list(_unique_state_dict(self, keep_vars=True).values())

        # Filled in by ``wrapper`` below; lists are used so the closure can
        # hand results back to this scope.
        ret_inputs = []
        inputs_states = []
        outs = []

        def wrapper(*args):
            # The tracer calls this with ``in_vars + module_state``; only the
            # first ``len(in_vars)`` positions are the user inputs.
            in_args: List[torch.Tensor] = []
            for i in range(len(in_vars)):
                if not isinstance(args[i], torch.Tensor):
                    raise RuntimeError('Expected Tensor argument')
                in_args.append(args[i])

            trace_inputs = _unflatten(in_args, in_desc)

            # Clones of every positional argument received (inputs and state).
            ret_inputs.append(
                tuple(x.clone(memory_format=torch.preserve_format) for x in args)
            )
            if self._return_inputs_states:
                inputs_states.append(_unflatten(in_args, in_desc))
            outs.append(self.inner(*trace_inputs))
            if self._return_inputs_states:
                # Pair the structure recorded before the call with the
                # structure that was actually passed to ``inner``.
                inputs_states[0] = (inputs_states[0], trace_inputs)
            out_vars, _ = _flatten(outs)
            if len(out_vars) == 1:
                return out_vars[0]
            else:
                return tuple(out_vars)

        graph, out = torch._C._create_graph_by_tracing(
            wrapper,
            in_vars + module_state,
            _create_interpreter_name_lookup_fn(),
            self.strict,
            self._force_outplace,
        )

        # ``out`` from the tracer is not used; the unflattened result captured
        # in ``outs`` is returned instead.
        if self._return_inputs:
            return graph, outs[0], ret_inputs[0]
        if self._return_inputs_states:
            return graph, outs[0], inputs_states[0]
        else:
            return graph, outs[0]


def _clone_inputs(args):
    """Return a copy of ``args`` in which every tensor has been cloned.

    ``args`` is walked with ``torch.autograd.function._nested_map``; each
    tensor found is replaced by a detached clone that keeps the source
    tensor's ``requires_grad`` flag and, when the source has a gradient, a
    clone of that gradient.

    Args:
        args: A tensor or a nested structure of tensors accepted by
            ``_nested_map``.

    Returns:
        The same structure with cloned tensors.
    """

    def clone_input(a):
        if a is None:
            return None
        elif isinstance(a, torch.Tensor):
            # TODO: figure out one liner to .clone() and set requires_grad
            v = (
                a.detach()
                .clone(memory_format=None if a.is_mkldnn else torch.preserve_format)
                .requires_grad_(a.requires_grad)
            )
            if a.grad is not None:
                # Copy the *source* tensor's gradient. The fresh clone ``v``
                # never has a gradient of its own, so reading ``v.grad`` here
                # would always yield None and drop the gradient.
                v.grad = clone_input(a.grad)
            return v
        else:
            return a.clone(memory_format=torch.preserve_format)

    return function._nested_map(
        lambda x: isinstance(x, torch.Tensor), clone_input, condition_msg="tensors"
    )(args)


# Developer-only debugging switches; deliberately not advertised.
# Each name holds the raw environment string when the variable is set (so any
# non-empty value is truthy) and ``False`` when it is not set.
_JIT_TIME = os.getenv("PYTORCH_JIT_TIME", False)  # CUDA-only timing
_JIT_DISABLE = os.getenv("PYTORCH_JIT_DISABLE", False)
_JIT_STATS = os.getenv("PYTORCH_JIT_STATS", False)


@contextlib.contextmanager
def _time(trace_name, name, time=True):
    """Context manager that prints the CUDA time spent in its body.

    Timing happens only when CUDA is available and either ``_JIT_TIME`` or
    the ``time`` argument is truthy; otherwise the body simply runs.

    Args:
        trace_name: First label in the printed line.
        name: Second label in the printed line.
        time: Request timing even when ``_JIT_TIME`` is not set.
    """
    timing_requested = _JIT_TIME or time
    if not timing_requested or not torch.cuda.is_available():
        yield
        return

    cuda_stream = torch.cuda.current_stream()
    began = torch.cuda.Event(enable_timing=True)
    finished = torch.cuda.Event(enable_timing=True)
    cuda_stream.record_event(began)
    try:
        yield
    finally:
        # Runs even when the body raises, so a line is always printed.
        cuda_stream.record_event(finished)
        finished.synchronize()
        elapsed_ms = began.elapsed_time(finished)
        print("{} {} time: {} ms".format(trace_name, name, elapsed_ms))


def verify(model, args, loss_fn=torch.sum, devices=None):
    """
    Verify that a JIT compiled model has the same behavior as its uncompiled
    version along with its backwards pass.  If your model returns multiple
    outputs, you must also specify a `loss_fn` to produce a loss for which
    the backwards will be computed.

    This function has side-effects (e.g., it executes your model / saves and loads
    parameters), so don't expect the model to come out exactly the same as what
    you passed in.

    Args:
        model (compiled torch.nn.Module or function): the module/function to be
            verified.  The module/function definition MUST have been decorated with
            `@torch.jit.compile`.
        args (tuple or Tensor): the positional arguments to pass to the
            compiled function/module to be verified.  A non-tuple is assumed to
            be a single positional argument to be passed to the model.
        loss_fn (function, optional): the loss function to be applied to
            the output of the model, before backwards is invoked.  By default,
            we assume that a model returns a single result, and we :func:`torch.sum`
            before calling backwards; if this is inappropriate, you can pass your
            own loss function.  Note that if a model returns a tuple of results,
            these are passed as separate positional arguments to `loss_fn`.
        devices (iterable of device IDs, optional): the GPU devices which the
            compiled module will be run on.  This determines the RNG state we
            must save when running both compiled and uncompiled versions of the model.

    Raises:
        TypeError: if `model` is not a `torch._C.CompiledFunction`.
        ValueError: if the default `loss_fn` is used and the model returns
            more than one output.
        RuntimeError: if the second run does not increase `model.hits`, or if
            outputs/gradients of the two runs differ (see `_verify_equal`).
        AssertionError: if `model.has_trace_for(*args)` is false after the
            first run.
    """
    # TODO: In principle, we track device information in our trace, so it
    # should be possible to check if our execution actually obeyed the 'devices'
    # the user provided.

    # TODO: Consider adding a utility function to torch.jit to test
    # for this case
    if not isinstance(model, torch._C.CompiledFunction):  # type: ignore[attr-defined]
        raise TypeError(
            "Cannot verify an uncompiled module.  Add @torch.jit.compile to compile it"
        )
    is_module = isinstance(model, Module)

    if not isinstance(args, tuple):
        args = (args,)

    # NOTE(review): ``saved_args`` is not read again in this function; the
    # call still walks ``args`` via ``_clone_inputs``.
    saved_args = _clone_inputs(args)
    if is_module:
        # Snapshot the state so it can be restored between the two runs.
        saved_state = copy.deepcopy(model.state_dict())

    def run_fwd_bwd(args, force_trace=False, assert_compiled=False):
        # Runs forward and backward once and returns detached clones of the
        # flattened outputs and of the gradients w.r.t. inputs and parameters.
        params = list(model.parameters()) if is_module else []
        in_vars, _ = _flatten((args, params))
        # We use a special API to reset the trace and compile it from scratch.
        compiled_fn = model
        if force_trace:
            compiled_fn.clear_cache()
        if assert_compiled:
            hits = compiled_fn.hits
        out = model(*args)
        # An unchanged hit counter means the cached compiled code was not used.
        if assert_compiled and compiled_fn.hits == hits:
            raise RuntimeError("failed to use the compiled function")
        if not isinstance(out, tuple):
            out = (out,)
        if loss_fn == torch.sum and len(out) != 1:
            raise ValueError(
                (
                    "Model returns {} outputs, but default loss function "
                    "(torch.sum) can only handle a single output"
                ).format(len(out))
            )
        out_vars, _ = _flatten(out)
        saved_outs = [
            v.detach().clone(memory_format=torch.preserve_format) for v in out_vars
        ]
        loss = loss_fn(*out)
        grads = torch.autograd.grad([loss], in_vars)
        # TODO: I'm not sure if the clone here is necessary but it is safer
        saved_grads = [
            v.detach().clone(memory_format=torch.preserve_format) for v in grads
        ]
        return (saved_outs, saved_grads)

    # First run: cache cleared, inside a forked RNG scope so the RNG state is
    # restored before the second run.
    with torch.random.fork_rng(devices, _caller="torch.jit.verify"):
        uncompiled_outs, uncompiled_grads = run_fwd_bwd(args, force_trace=True)
        assert model.has_trace_for(*args)

    if is_module:
        model.load_state_dict(saved_state)
    # Second run: must go through the compiled path.
    compiled_outs, compiled_grads = run_fwd_bwd(args, assert_compiled=True)

    _verify_equal(uncompiled_outs, compiled_outs)
    _verify_equal(uncompiled_grads, compiled_grads)


def _verify_equal(xs, ys):
    """Raise if two sequences of tensors are not element-wise equal within 1e-6.

    Tensors are compared pairwise by position using the largest absolute
    difference.

    Args:
        xs: Iterable of tensors.
        ys: Iterable of tensors to compare against ``xs``.

    Raises:
        RuntimeError: if the sequences have different lengths or any pair
            differs by more than ``1e-6`` in some element.
    """
    xs, ys = list(xs), list(ys)
    # ``zip`` alone would silently ignore trailing entries of the longer
    # sequence; a different number of tensors is itself a mismatch.
    if len(xs) != len(ys):
        raise RuntimeError("JIT and real computation mismatch")
    for x, y in zip(xs, ys):
        diff = x.sub(y).abs()
        # ``max()`` raises on a tensor with no elements; there is nothing to
        # compare in that case.
        if diff.numel() > 0 and diff.max() > 1e-6:
            raise RuntimeError("JIT and real computation mismatch")


def indent(s):
    """Return ``s`` with a tab prepended to each line, joined by ``"\\n"``.

    Lines are split with ``str.splitlines``, so a trailing newline in ``s``
    is not kept in the result.
    """
    tabbed_lines = (f"\t{line}" for line in s.splitlines())
    return "\n".join(tabbed_lines)


class TracingCheckError(Exception):
    """Raised when a trace fails its sanity checks.

    The full text is available both as ``str(exc)`` and as ``exc.message``.
    """

    def __init__(self, graph_diff_error, tensor_compare_error, extra_msg=None):
        """Assemble the message from whichever parts are present.

        Args:
            graph_diff_error: Description of how the graphs differ, or None.
            tensor_compare_error: Description of differing tensor constants,
                or None.
            extra_msg: Optional text placed right after the headline.
        """
        sections = ["Tracing failed sanity checks!\n"]
        if extra_msg is not None:
            sections.append(extra_msg + "\n")
        if graph_diff_error is not None:
            sections.append("ERROR: Graphs differed across invocations!\n")
            sections.append(indent(graph_diff_error) + "\n")
        if tensor_compare_error is not None:
            sections.append(
                "ERROR: Tensor-valued Constant nodes differed in value "
                "across invocations. This often indicates that the tracer has"
                " encountered untraceable code.\n"
            )
            sections.append(indent(tensor_compare_error) + "\n")
        self.message = "".join(sections)
        super().__init__(self.message)


# Check the traced module against a set of user-provided validation inputs
@torch.no_grad()
def _check_trace(
    check_inputs,
    func,
    traced_func,
    check_tolerance,
    strict,
    force_outplace,
    is_trace_module,
    _module_class,
):
    """Re-trace ``func`` on each check input and compare against ``traced_func``.

    For every entry of ``check_inputs`` a fresh trace is created (with trace
    checking disabled), then:

    * the tensor outputs of ``traced_func`` are compared with those of
      ``func`` and, if they match, with those of the fresh trace; mismatches
      produce ``TracerWarning`` warnings rather than exceptions;
    * the graphs and tensor-valued constants of the two traces are compared;
      any difference raises ``TracingCheckError``.

    Args:
        check_inputs: Iterable of inputs. A lone tensor is wrapped in a
            1-tuple. When ``is_trace_module`` is true each entry is a dict
            keyed by method name.
        func: The original Python callable. When ``is_trace_module`` is true
            and ``func`` has ``__self__``, that object is what gets re-traced.
        traced_func: The trace under test; its ``graph`` is read, and its
            ``name`` too when ``is_trace_module`` is true.
        check_tolerance: Passed as ``rtol`` to ``torch.testing.assert_close``.
        strict: Forwarded to the re-trace.
        force_outplace: Forwarded to the re-trace as ``_force_outplace``.
        is_trace_module: Selects ``torch.jit.trace_module`` instead of
            ``torch.jit.trace`` for the re-trace.
        _module_class: Forwarded to the re-trace.

    Raises:
        TracingCheckError: if graphs or tensor constants differ, or if running
            any of the callables on the check inputs raises.
    """
    # Note: tracing is independent of optimizations, which consume the trace
    for inputs in check_inputs:

        if isinstance(inputs, torch.Tensor):
            inputs = (inputs,)

        if is_trace_module:
            # Clone per method so the re-trace cannot mutate the caller's inputs.
            copied_dict = {}
            for name, data in inputs.items():
                copied_dict[name] = _clone_inputs(data)
            check_mod = torch.jit.trace_module(
                func.__self__ if hasattr(func, "__self__") else func,
                copied_dict,
                check_trace=False,
                strict=strict,
                _force_outplace=force_outplace,
                _module_class=_module_class,
                _compilation_unit=torch._C.CompilationUnit(),
            )
            check_mod_func = check_mod._c._get_method(traced_func.name)
            # From here on only the inputs of the method under test matter.
            inputs = inputs[traced_func.name]
            if isinstance(inputs, (torch.Tensor, dict)):
                inputs = (inputs,)
        else:
            check_mod = torch.jit.trace(
                func,
                _clone_inputs(inputs),
                check_trace=False,
                strict=strict,
                _force_outplace=force_outplace,
                _module_class=_module_class,
            )
            check_mod_func = check_mod

        def graph_diagnostic_info():
            # Returns ``(graph_diff_errors, tensor_compare_errors)``; each is
            # None when no difference of that kind was found.
            mod_canonicalized = torch._C._jit_pass_canonicalize(traced_func.graph)
            torch._C._jit_pass_inline(mod_canonicalized)
            torch._C._jit_pass_erase_shape_information(mod_canonicalized)
            mod_str = str(mod_canonicalized)
            # Strip ``___torch_mangle_<N>.`` prefixes before comparing the text.
            mod_str = re.sub(r"___torch_mangle_[0-9]+\.", "", mod_str)
            check_canonicalized = torch._C._jit_pass_canonicalize(check_mod_func.graph)
            torch._C._jit_pass_inline(check_canonicalized)
            torch._C._jit_pass_erase_shape_information(check_canonicalized)
            check_str = str(check_canonicalized)
            check_str = re.sub(r"___torch_mangle_[0-9]+\.", "", check_str)

            graph_diff_errors = None
            if mod_str != check_str:
                import difflib

                graph_diff = difflib.ndiff(
                    mod_str.splitlines(True), check_str.splitlines(True)
                )
                graph_diff_errors = "Graph diff:\n" + indent("".join(graph_diff)) + "\n"

                for n_mod, n_check in zip(
                    mod_canonicalized.nodes(), check_canonicalized.nodes()
                ):
                    if str(n_mod) != str(n_check):
                        graph_diff_errors += "First diverging operator:\n"
                        node_diff = difflib.ndiff(
                            str(n_mod).splitlines(True), str(n_check).splitlines(True)
                        )
                        source_printout = (
                            "Node diff:\n" + indent("".join(node_diff)) + "\n"
                        )
                        mod_stack = n_mod.sourceRange()
                        if mod_stack:
                            source_printout += (
                                "Trace source location:\n" + indent(mod_stack) + "\n"
                            )
                        check_stack = n_check.sourceRange()
                        if check_stack:
                            source_printout += (
                                "Check source location:\n" + indent(check_stack) + "\n"
                            )
                        graph_diff_errors += source_printout

                        break  # For now, only print out the first pair of nodes that diverges

            tensor_compare_errors = None
            # Check Tensor-valued constant nodes
            for n_mod, n_check in zip(
                mod_canonicalized.nodes(), check_canonicalized.nodes()
            ):
                if n_mod.kind() != n_check.kind():
                    break  # Graphs have already diverged

                if n_mod.kind() == "prim::Constant" and not (
                    n_mod.mustBeNone() or n_check.mustBeNone()
                ):
                    if not n_mod.hasAttribute("value"):
                        continue
                    # Only constants whose "value" attribute kind is "t" are compared.
                    if n_mod.kindOf("value") != "t" or n_check.kindOf("value") != "t":
                        continue

                    mod_tensor_val = n_mod.t("value")
                    check_tensor_val = n_check.t("value")

                    try:
                        torch.testing.assert_close(mod_tensor_val, check_tensor_val, equal_nan=True)
                    except (RuntimeError, AssertionError) as e:
                        if tensor_compare_errors is None:
                            tensor_compare_errors = ""
                        tensor_compare_errors += "Node:\n" + indent(str(n_mod)) + "\n"
                        compare_stack = n_mod.sourceRange()
                        if compare_stack:
                            tensor_compare_errors += (
                                "Source Location:\n" + indent(compare_stack) + "\n"
                            )
                        tensor_compare_errors += "Comparison exception: " + indent(
                            str(e)
                        )

                        break  # For now, only print the first diverging pair

            return graph_diff_errors, tensor_compare_errors

        def wrap_retval(x):
            # Normalise a single return value to a 1-tuple.
            return x if isinstance(x, tuple) else (x,)

        def run_mod_and_filter_tensor_outputs(mod, inputs, running_what):
            # Runs ``mod`` on cloned inputs and keeps only the tensor outputs.
            # Any exception is re-raised as TracingCheckError carrying the
            # graph diagnostics.
            try:
                outs = wrap_retval(mod(*_clone_inputs(inputs)))
                outs = [out for out in outs if isinstance(out, torch.Tensor)]
                return outs
            except Exception as e:
                graph_diff_errors, tensor_compare_errors = graph_diagnostic_info()
                msg = f"encountered an exception while running the {running_what} with test inputs.\nException:\n{indent(str(e))}"
                raise TracingCheckError(
                    graph_diff_errors,
                    tensor_compare_errors,
                    extra_msg=msg,
                ) from e

        # One-element list so the nested function can flip the flag.
        has_warned = [False]

        def maybe_warn_nondeterministic():
            # Warns at most once per check input; lists at most 20 nodes.
            if has_warned[0]:
                return
            has_warned[0] = True
            nondeterm_ops = [
                op for op in traced_func.graph.nodes() if op.isNondeterministic()
            ]
            if len(nondeterm_ops) > 0:
                nondeterministic_ops_warning = "Trace had nondeterministic nodes. "
                nondeterministic_ops_warning += (
                    "Did you forget call .eval() on your model? Nodes:\n"
                )
                nondeterministic_ops_warning += "\n".join(
                    [indent(str(op)) for op in nondeterm_ops][:20]
                )
                nondeterministic_ops_warning += (
                    "\nThis may cause errors in trace checking. To disable trace checking,"
                    " pass check_trace=False to torch.jit.trace()"
                )
                warnings.warn(
                    nondeterministic_ops_warning, category=TracerWarning, stacklevel=5
                )

        def compare_outputs(original, reference, match_what):
            # Compares outputs pairwise by position (``zip`` stops at the
            # shorter sequence). Mismatches are reported as warnings; the
            # return value says whether every pair matched.
            all_ok = True
            for i, (orig, ref) in enumerate(zip(original, reference)):
                try:
                    if orig.is_quantized:
                        orig = orig.dequantize()
                    if ref.is_quantized:
                        ref = ref.dequantize()
                    if orig.is_mkldnn:
                        orig = orig.to_dense()
                    if ref.is_mkldnn:
                        ref = ref.to_dense()
                    # Values are compared in double precision (complex double
                    # when either side is complex).
                    if ref.is_complex() or orig.is_complex():
                        torch.testing.assert_close(
                            orig.to(torch.cdouble),
                            ref.to(torch.cdouble),
                            rtol=check_tolerance,
                            atol=default_tolerances(orig, ref)[1],
                            equal_nan=True,
                        )
                    else:
                        torch.testing.assert_close(
                            orig.double(),
                            ref.double(),
                            rtol=check_tolerance,
                            atol=default_tolerances(orig, ref)[1],
                            equal_nan=True,
                        )
                except AssertionError as e:
                    maybe_warn_nondeterministic()
                    warnings.warn(
                        "Output nr "
                        + str(i + 1)
                        + ". of the traced function does not match "
                        "the corresponding output of the "
                        + match_what
                        + ". Detailed error:\n"
                        + str(e),
                        category=TracerWarning,
                        stacklevel=4,
                    )
                    all_ok = False

            return all_ok

        traced_outs = run_mod_and_filter_tensor_outputs(traced_func, inputs, "trace")
        fn_outs = run_mod_and_filter_tensor_outputs(func, inputs, "Python function")
        # The repeated trace is only run when the trace already agrees with
        # the Python function.
        if compare_outputs(traced_outs, fn_outs, "Python function"):
            check_outs = run_mod_and_filter_tensor_outputs(
                check_mod_func, inputs, "repeated trace"
            )
            compare_outputs(traced_outs, check_outs, "repeated trace")

        diag_info = graph_diagnostic_info()
        if any(info is not None for info in diag_info):
            raise TracingCheckError(*diag_info)


class TracerWarning(Warning):
    """Warning category used for tracer diagnostics."""

    @staticmethod
    def ignore_lib_warnings():
        """Install an ``"ignore"`` filter for ``TracerWarning`` by module name."""
        # We ignore warnings from all submodules excluding the JIT, because we need them e.g. for _check_trace
        non_jit_torch_modules = "torch.(?!jit)"
        warnings.filterwarnings(
            action="ignore",
            module=non_jit_torch_modules,
            category=TracerWarning,
        )


# We ignore the tracer warnings coming from inside the library, because all our shape
# checks in nn will trigger them.
TracerWarning.ignore_lib_warnings()
# NOTE(review): presumably tells the C++ tracer to emit its warnings through
# Python's ``warnings`` machinery so the filter above applies — confirm.
torch._C._tracer_warn_use_python()


def make_tuple(example_inputs):
    """Normalise ``example_inputs`` to a tuple.

    A single tensor or dict becomes a 1-tuple, an existing tuple is returned
    unchanged, and any other iterable is converted with ``tuple()``.
    """
    is_single_input = isinstance(example_inputs, torch.Tensor) or isinstance(
        example_inputs, dict
    )
    if is_single_input:
        return (example_inputs,)
    # done primarily so that weird iterables fail here and not pybind11 code
    return example_inputs if isinstance(example_inputs, tuple) else tuple(example_inputs)


def make_module(mod, _module_class, _compilation_unit):
    """Return the module object that tracing should operate on.

    * A ``ScriptModule`` is returned unchanged.
    * A module for which ``torch._jit_internal.module_has_exports`` is true is
      turned into a script module via
      ``torch.jit._recursive.create_script_module``.
    * Anything else is wrapped in ``_module_class`` (``TopLevelTracedModule``
      when ``_module_class`` is None), constructed with ``_compilation_unit``.
    """
    if isinstance(mod, ScriptModule):
        return mod

    if not torch._jit_internal.module_has_exports(mod):
        wrapper_cls = TopLevelTracedModule if _module_class is None else _module_class
        return wrapper_cls(mod, _compilation_unit=_compilation_unit)

    stubs_from_exports = torch.jit._recursive.make_stubs_from_exported_methods
    return torch.jit._recursive.create_script_module(
        mod,
        stubs_from_exports,
        share_types=False,
        is_tracing=True,
    )


def wrap_check_inputs(check_inputs):
    """Wrap each check input as ``{"forward": input}``.

    Returns None when ``check_inputs`` is None, otherwise a new list with one
    single-key dict per entry.
    """
    if check_inputs is None:
        return None

    wrapped = []
    for single_input in check_inputs:
        wrapped.append({"forward": single_input})
    return wrapped


def trace(
    func,
    example_inputs,
    optimize=None,
    check_trace=True,
    check_inputs=None,
    check_tolerance=1e-5,
    strict=True,
    _force_outplace=False,
    _module_class=None,
    _compilation_unit=_python_cu,
):
    """
    Trace a function and return an executable :class:`ScriptModule` or
    :class:`ScriptFunction` that will be optimized using just-in-time
    compilation. Tracing is ideal for code that operates only on
    ``Tensor``\\s and lists, dictionaries, and tuples of ``Tensor``\\s.

    Using `torch.jit.trace` and `torch.jit.trace_module`, you can turn an
    existing module or Python function into a TorchScript
    :class:`ScriptFunction` or :class:`ScriptModule`. You must provide example
    inputs, and we run the function, recording the operations performed on all
    the tensors.

    * The resulting recording of a standalone function produces `ScriptFunction`.
    * The resulting recording of `nn.Module.forward` or `nn.Module` produces
      `ScriptModule`.

    This module also contains any parameters that the original
    module had as well.

    Warning:
        Tracing only correctly records functions and modules which are not data
        dependent (e.g., do not have conditionals on data in tensors) and do not have
        any untracked external dependencies (e.g., perform input/output or
        access global variables). Tracing only records operations done when the given
        function is run on the given tensors. Therefore, the returned
        `ScriptModule` will always run the same traced graph on any input. This
        has some important implications when your module is expected to run
        different sets of operations, depending on the input and/or the module
        state. For example,

        * Tracing will not record any control-flow like if-statements or loops.
          When this control-flow is constant across your module, this is fine
          and it often inlines the control-flow decisions. But sometimes the
          control-flow is actually part of the model itself. For instance, a
          recurrent network is a loop over the (possibly dynamic) length of an
          input sequence.
        * In the returned :class:`ScriptModule`, operations that have different
          behaviors in ``training`` and ``eval`` modes will always behave as if
          it is in the mode it was in during tracing, no matter which mode the
          `ScriptModule` is in.

        In cases like these, tracing would not be appropriate and
        :func:`scripting <torch.jit.script>` is a better choice. If you trace
        such models, you may silently get incorrect results on subsequent
        invocations of the model. The tracer will try to emit warnings when
        doing something that may cause an incorrect trace to be produced.

    Args:
        func (callable or torch.nn.Module):  A Python function or `torch.nn.Module`
            that will be run with `example_inputs`. `func` arguments and return
            values  must be tensors or (possibly nested) tuples that contain
            tensors. When a module is passed `torch.jit.trace`, only the
            ``forward`` method is run and traced (see
            :func:`torch.jit.trace_module <torch.jit.trace_module>` for details).
        example_inputs (tuple or torch.Tensor):  A tuple of example inputs that
            will be passed to the function while tracing. The resulting trace
            can be run with inputs of different types and shapes assuming the
            traced operations support those types and shapes. `example_inputs`
            may also be a single Tensor in which case it is automatically
            wrapped in a tuple.

    Keyword arguments:
        check_trace (``bool``, optional): Check if the same inputs run through
            traced code produce the same outputs. Default: ``True``. You might want
            to disable this if, for example, your network contains non-
            deterministic ops or if you are sure that the network is correct despite
            a checker failure.

        check_inputs (list of tuples, optional): A list of tuples of input
            arguments that should be used to check the trace against what is
            expected. Each tuple is equivalent to a set of input arguments that
            would be specified in ``example_inputs``. For best results, pass in
            a set of checking inputs representative of the space of shapes and
            types of inputs you expect the network to see.  If not specified,
            the original ``example_inputs`` are used for checking.
        check_tolerance (float, optional): Floating-point comparison tolerance
            to use in the checker procedure.  This can be used to relax the
            checker strictness in the event that results diverge numerically
            for a known reason, such as operator fusion.
        strict (``bool``, optional): run the tracer in a strict mode or not
            (default: ``True``). Only turn this off when you want the tracer to
            record your mutable container types (currently ``list``/``dict``)
            and you are sure that the container you are using in your
            problem is a ``constant`` structure and does not get used as
            control flow (if, for) conditions.

    Returns:
        If `func` is `nn.Module` or ``forward`` of `nn.Module`, `trace` returns
        a :class:`ScriptModule` object with a single ``forward`` method
        containing the traced code.  The returned `ScriptModule` will
        have the same set of sub-modules and parameters as the original
        ``nn.Module``.  If ``func`` is a standalone function, ``trace``
        returns `ScriptFunction`.

    Example (tracing a function):

    .. testcode::

        import torch

        def foo(x, y):
            return 2 * x + y

        # Run `foo` with the provided inputs and record the tensor operations
        traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))

        # `traced_foo` can now be run with the TorchScript interpreter or saved
        # and loaded in a Python-free environment

    Example (tracing an existing module)::

        import torch
        import torch.nn as nn

        class Net(nn.Module):
            def __init__(self):
                super(Net, self).__init__()
                self.conv = nn.Conv2d(1, 1, 3)

            def forward(self, x):
                return self.conv(x)

        n = Net()
        example_weight = torch.rand(1, 1, 3, 3)
        example_forward_input = torch.rand(1, 1, 3, 3)

        # Trace a specific method and construct `ScriptModule` with
        # a single `forward` method
        module = torch.jit.trace(n.forward, example_forward_input)

        # Trace a module (implicitly traces `forward`) and construct a
        # `ScriptModule` with a single `forward` method
        module = torch.jit.trace(n, example_forward_input)

    """
    # With the JIT disabled, tracing is a no-op that hands back the callable.
    if not _enabled:
        return func
    if optimize is not None:
        warnings.warn(
            "`optimize` is deprecated and has no effect. Use `with torch.jit.optimized_execution() instead"
        )

    if isinstance(func, torch.jit.ScriptModule):
        # it is hard to trace it because the forward method on ScriptModule is already defined, so it
        # would result in an error.
        warnings.warn(
            "The input to trace is already a ScriptModule, tracing it is a no-op. Returning the object as is."
        )
        return func

    # A Module, or the bound ``forward`` of a Module, is traced through
    # ``trace_module`` with its inputs filed under the "forward" key.
    module_to_trace = None
    if isinstance(func, torch.nn.Module):
        module_to_trace = func
    elif (
        hasattr(func, "__self__")
        and isinstance(func.__self__, torch.nn.Module)
        and func.__name__ == "forward"
    ):
        module_to_trace = func.__self__

    if module_to_trace is not None:
        return trace_module(
            module_to_trace,
            {"forward": example_inputs},
            None,
            check_trace,
            wrap_check_inputs(check_inputs),
            check_tolerance,
            strict,
            _force_outplace,
            _module_class,
        )

    # Special case for common case of passing a single Tensor
    if isinstance(example_inputs, (torch.Tensor, dict)):
        example_inputs = (example_inputs,)
    # done primarily so that weird iterables fail here and not pybind11 code
    elif not isinstance(example_inputs, tuple):
        example_inputs = tuple(example_inputs)

    var_lookup_fn = _create_interpreter_name_lookup_fn(0)

    # Any other bound method of a Module is rejected.
    if hasattr(func, "__self__") and isinstance(func.__self__, torch.nn.Module):
        raise AttributeError(
            "trace doesn't support compiling individual module's functions.\n"
            "Please use trace_module"
        )

    name = _qualified_name(func)
    traced = torch._C._create_function_from_trace(
        name,
        func,
        example_inputs,
        var_lookup_fn,
        strict,
        _force_outplace,
        get_callable_argument_names(func)
    )

    # Check the trace against new traces created from user-specified inputs,
    # falling back to the example inputs when none were given.
    if check_trace:
        _check_trace(
            check_inputs if check_inputs is not None else [example_inputs],
            func,
            traced,
            check_tolerance,
            strict,
            _force_outplace,
            False,
            _module_class,
        )

    return traced


# Module-level slot typed as an optional mapping; it starts out as None.
# NOTE(review): presumably populated while a module is being traced — the
# code that assigns it is outside this view; confirm before relying on it.
_trace_module_map: Optional[Dict[Any, Any]] = None


def trace_module(
    mod,
    inputs,
    optimize=None,
    check_trace=True,
    check_inputs=None,
    check_tolerance=1e-5,
    strict=True,
    _force_outplace=False,
    _module_class=None,
    _compilation_unit=_python_cu,
):
    """
    Trace a module and return an executable :class:`ScriptModule` that will be optimized
    using just-in-time compilation. When a module is passed to :func:`torch.jit.trace <torch.jit.trace>`,
    only the ``forward`` method is run and traced. With ``trace_module``, you can specify a dictionary
    of method names to example inputs to trace (see the ``inputs`` argument below).

    See :func:`torch.jit.trace <torch.jit.trace>` for more information on tracing.

    Args:
        mod (torch.nn.Module):  A ``torch.nn.Module`` containing methods whose names are
                                specified in ``inputs``. The given methods will be compiled
                                as a part of a single `ScriptModule`.
        inputs (dict):  A dict containing sample inputs indexed by method names in ``mod``.
                                The inputs will be passed to methods whose names correspond to inputs'
                                keys while tracing.
                                ``{ 'forward' : example_forward_input, 'method2': example_method2_input}``
    Keyword arguments:
        optimize: Deprecated and ignored; passing anything other than ``None`` only emits a warning.
        check_trace (``bool``, optional): Check if the same inputs run through
                                      traced code produce the same outputs. Default: ``True``. You might want
                                      to disable this if, for example, your network contains non-
                                      deterministic ops or if you are sure that the network is correct despite
                                      a checker failure.

        check_inputs (list of dicts, optional): A list of dicts of input arguments that should be used
                                                 to check the trace against what is expected. Each dict
                                                 is equivalent to a set of input arguments that would
                                                 be specified in ``inputs``. For best results, pass in a
                                                 set of checking inputs representative of the space of
                                                 shapes and types of inputs you expect the network to see.
                                                 If not specified, the original ``inputs`` are used for checking
        check_tolerance (float, optional): Floating-point comparison tolerance to use in the checker procedure.
                                           This can be used to relax the checker strictness in the event that
                                           results diverge numerically for a known reason, such as operator fusion.
        strict, _force_outplace: Forwarded unchanged to the tracer and to the trace checker.
        _module_class, _compilation_unit: Forwarded to ``make_module`` when the result module is built
                                           (``_module_class`` is also forwarded to the trace checker).

    Returns:
        A :class:`ScriptModule` object holding one traced method for every key of ``inputs``.
        If the JIT is disabled, ``mod`` itself is returned untouched.

    Raises:
        AttributeError: if ``mod`` is not a ``torch.nn.Module`` or ``inputs`` is not a ``dict``.

    Example (tracing a module with multiple methods)::

        import torch
        import torch.nn as nn

        class Net(nn.Module):
            def __init__(self):
                super(Net, self).__init__()
                self.conv = nn.Conv2d(1, 1, 3)

            def forward(self, x):
                return self.conv(x)

            def weighted_kernel_sum(self, weight):
                return weight * self.conv.weight


        n = Net()
        example_weight = torch.rand(1, 1, 3, 3)
        example_forward_input = torch.rand(1, 1, 3, 3)

        # Trace a specific method and construct `ScriptModule` with
        # a single `forward` method
        module = torch.jit.trace(n.forward, example_forward_input)

        # Trace a module (implicitly traces `forward`) and construct a
        # `ScriptModule` with a single `forward` method
        module = torch.jit.trace(n, example_forward_input)

        # Trace specific methods on a module (specified in `inputs`), constructs
        # a `ScriptModule` with `forward` and `weighted_kernel_sum` methods
        inputs = {'forward' : example_forward_input, 'weighted_kernel_sum' : example_weight}
        module = torch.jit.trace_module(n, inputs)

    """
    if not _enabled:
        return mod
    if optimize is not None:
        warnings.warn(
            "`optimize` is deprecated and has no effect. Use `with torch.jit.optimized_execution() instead"
        )

    if not isinstance(mod, torch.nn.Module):
        raise AttributeError("expected torch.nn.Module as the first argument")

    if not isinstance(inputs, dict):
        raise AttributeError("expected a dictionary of (method_name, input) pairs")

    var_lookup_fn = _create_interpreter_name_lookup_fn(0)

    # Remember whatever map was installed before so that it can be put back
    # once this trace is over, even if tracing raises.
    saved_module_map = torch.jit._trace._trace_module_map
    try:
        qualname_map: Dict[Any, Any] = {}

        def _register_descendants(parent, parent_qualname):
            # Depth-first walk recording ``submodule -> dotted qualified name``.
            for child_name, child in parent.named_children():
                child_qualname = ".".join((parent_qualname, child_name))
                qualname_map[child] = child_qualname
                _register_descendants(child, child_qualname)

        # NOTE(review): the root entry is keyed by the string and valued by the
        # module, the reverse of the submodule entries — kept exactly as is.
        qualname_map["__module"] = mod
        torch.jit._trace._trace_module_map = qualname_map
        _register_descendants(mod, "__module")

        module = make_module(mod, _module_class, _compilation_unit)

        for method_name, example_inputs in inputs.items():
            bound_method = getattr(mod, method_name)
            argument_names = get_callable_argument_names(bound_method)
            # "forward" is a special case because we need to trace
            # `Module.__call__`, which sets up some extra tracing, but uses
            # argument names of the real `Module.forward` method.
            func = mod if method_name == "forward" else bound_method

            example_inputs = make_tuple(example_inputs)

            module._c._create_method_from_trace(
                method_name,
                func,
                example_inputs,
                var_lookup_fn,
                strict,
                _force_outplace,
                argument_names,
            )
            check_trace_method = module._c._get_method(method_name)

            # Check the trace against new traces created from user-specified
            # inputs, falling back to the original example inputs.
            if check_trace:
                _check_trace(
                    [inputs] if check_inputs is None else check_inputs,
                    func,
                    check_trace_method,
                    check_tolerance,
                    strict,
                    _force_outplace,
                    True,
                    _module_class,
                )
    finally:
        torch.jit._trace._trace_module_map = saved_module_map

    return module


def is_tracing():
    """
    Returns ``True`` in tracing (if a function is called during the tracing of
    code with ``torch.jit.trace``) and ``False`` otherwise.

    Always returns ``False`` while scripting, regardless of the tracer state.
    """
    if is_scripting():
        return False
    return torch._C._is_tracing()
class TracedModule(ScriptModule): _disable_script_meta = True def __init__(self, orig, id_set=None, _compilation_unit=None): # XXX: orig can be a nn.Module or a function! super(TracedModule, self).__init__() assert isinstance(orig, torch.nn.Module) # Copy a subset of `orig` to a temporary nn.Module. # This is a way to customize what will actually get compiled by create_script_module id_set = set() # This allows us to preserve the original module's qualified name by defining a new # type with the attribute _jit_override_qualname. In torch._jit_internal._qualified_name # we have a special case that will look up this attribute to override whatever qualname # we would get from the python type system class QualnameWrapper(torch.nn.Module): pass QualnameWrapper._jit_override_qualname = torch._jit_internal._qualified_name( # type: ignore[attr-defined] type(orig) ) tmp_module = QualnameWrapper() def check_unique(param): if param in id_set: raise ValueError( "TracedModules don't support parameter sharing between modules" ) id_set.add(param) tmp_module.training = orig.training for name, param in orig._parameters.items(): if param is not None: tmp_module._parameters[name] = param check_unique(param) for name, buf in orig._buffers.items(): if buf is not None: tmp_module._buffers[name] = buf check_unique(buf) for name, val in orig.__dict__.items(): if ( torch._C._jit_is_script_object(val) and name not in orig._parameters and name not in orig._buffers ): setattr(tmp_module, name, val) if orig._backward_hooks: raise ValueError( "Modules that have backward hooks assigned can't be compiled: " + str(orig) ) for name, submodule in orig._modules.items(): if submodule is None: continue tmp_module._modules[name] = make_module( submodule, TracedModule, _compilation_unit=None ) script_module = torch.jit._recursive.create_script_module( tmp_module, lambda module: (), share_types=False, is_tracing=True ) self.__dict__["_name"] = type(orig).__name__ self.__dict__["_actual_script_module"] = 
script_module for name in ("_parameters", "_buffers", "_modules", "training"): delattr(self, name) def forward(self, *args, **kwargs): raise RuntimeError("Trace submodules cannot be called.") def __getattr__(self, attr): if "_actual_script_module" not in self.__dict__: return super(TracedModule, self).__getattr__(attr) return getattr(self._actual_script_module, attr) def __setattr__(self, attr, value): if "_actual_script_module" not in self.__dict__: return super(TracedModule, self).__setattr__(attr, value) setattr(self._actual_script_module, attr, value) def _get_name(self): return self._name def extra_repr(self): return "original_name={}".format(self._name) class TopLevelTracedModule(TracedModule): forward = _CachedForward() def _reconstruct(self, cpp_module): """ Re-construct an instance of TopLevelTracedModule using an instance of a C++ module. Args: cpp_module: The C++ module that this TopLevelTracedModule will be rebuilt around. """ self.__dict__["_actual_script_module"]._reconstruct(cpp_module) def _script_if_tracing(fn): @functools.wraps(fn) def wrapper(*args, **kwargs): if not is_tracing(): # Not tracing, don't do anything return fn(*args, **kwargs) compiled_fn = script(wrapper.__original_fn) # type: ignore[attr-defined] return compiled_fn(*args, **kwargs) wrapper.__original_fn = fn # type: ignore[attr-defined] wrapper.__script_if_tracing_wrapper = True # type: ignore[attr-defined] return wrapper def _get_trace_graph(f, args=(), kwargs=None, strict=True, _force_outplace=False, return_inputs=False, _return_inputs_states=False): """ .. warning:: This function is internal-only and should only be used by the ONNX exporter. If you are trying to get a graph through tracing, please go through the public API instead:: trace = torch.jit.trace(nn.LSTMCell(), (input, hidden)) trace_graph = trace.graph Trace a function or model, returning a tuple consisting of the both the *trace* of an execution, as well as the original return value. 
If return_inputs, also returns the trace inputs as part of the tuple Tracing is guaranteed not to change the semantics of the function/module that is traced. Args: f (torch.nn.Module or function): the function or module to be traced. args (tuple or Tensor): the positional arguments to pass to the function/module to be traced. A non-tuple is assumed to be a single positional argument to be passed to the model. kwargs (dict): the keyword arguments to pass to the function/module to be traced. Example (trace a cell): .. testcode:: trace = torch.jit.trace(nn.LSTMCell(), (input, hidden)) """ if kwargs is None: kwargs = {} if not isinstance(args, tuple): args = (args,) outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs) return outs

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources