Source code for DeBERTa.utils.jit_tracing

"""
Torch jit tracing util
@Author: penhe@microsoft.com
"""

""" Utils for torch jit tracing of custom operators/functions
"""
import os

def traceable(cls):
  """ Decorator over custom functions
      There is an issue for tracing custom python torch Function, using this decorator to work around it.
      e.g.
      @traceable
      class MyOp(torch.autograd.Function):
      xxx

      The decorated class is replaced by a thin proxy class whose ``apply``
      decides, on every call, which path to take:

      * if the ``JIT_TRACE`` environment variable equals ``'true'``
        (case-insensitive), ``cls.forward`` is called directly with the proxy
        class itself standing in for the autograd context;
      * otherwise the call is delegated unchanged to ``cls.apply``.

      Args:
        cls: class exposing ``forward(ctx, *args)`` and ``apply(*args)``.

      Returns:
        The proxy class, carrying the ``__name__`` and ``__doc__`` of ``cls``.
  """

  def _ignore(*args, **kwargs):
    # Context bookkeeping has nothing to record when forward is called directly.
    return None

  class _Function(object):
    @staticmethod
    def apply(*args):
      # The flag is read on every call so the mode can change at runtime.
      jit_trace = (os.getenv('JIT_TRACE', 'False').lower() == 'true')
      if jit_trace:
        # The proxy class doubles as the ``ctx`` argument of forward.
        return cls.forward(_Function, *args)
      return cls.apply(*args)

    # No-op stand-ins for the context methods a forward may call; without them
    # such a forward would raise AttributeError only while tracing.
    save_for_backward = staticmethod(_ignore)
    save_for_forward = staticmethod(_ignore)
    mark_dirty = staticmethod(_ignore)
    mark_non_differentiable = staticmethod(_ignore)
    set_materialize_grads = staticmethod(_ignore)

  _Function.__name__ = cls.__name__
  _Function.__doc__ = cls.__doc__
  return _Function

class TraceMode():
  """ Trace context used when tracing modules that contain custom operators/Functions

      While the context is active the ``JIT_TRACE`` environment variable is set
      to ``'True'``. On exit the variable is put back to the value it had on
      entry, or removed if it was unset, so contexts can be nested and a
      pre-existing value is not lost.
  """
  def __init__(self):
    # One saved value per active __enter__, so the same instance can be re-entered.
    self._saved = []

  def __enter__(self):
    """ Remember the current ``JIT_TRACE`` value, enable trace mode and return self.
    """
    self._saved.append(os.environ.get('JIT_TRACE'))
    os.environ['JIT_TRACE'] = 'True'
    return self

  def __exit__(self, exp_value, exp_type, trace):
    """ Restore ``JIT_TRACE`` to its value from the matching ``__enter__``.

        The ``with`` statement passes (exception type, exception value,
        traceback) positionally; the parameter names are kept as they were for
        compatibility. Nothing is returned, so exceptions propagate.
    """
    previous = self._saved.pop() if self._saved else None
    if previous is None:
      # pop() with a default tolerates the variable having been removed meanwhile.
      os.environ.pop('JIT_TRACE', None)
    else:
      os.environ['JIT_TRACE'] = previous