"""
JIT tracing utils
@Author: penhe@microsoft.com
"""
""" Utils for torch jit tracing customer operators/functions
"""
import os
def traceable(cls):
    """Decorator that works around tracing issues with custom torch Functions.

    There is an issue when tracing custom Python ``torch.autograd.Function``
    classes; this decorator works around it. The returned wrapper exposes an
    ``apply`` method with two modes:

    * If the ``JIT_TRACE`` environment variable equals ``'true'``
      (case-insensitive), it calls ``cls.forward(ctx, *args)`` directly. A
      fresh stand-in ``ctx`` object is used for each call.
    * Otherwise it delegates to ``cls.apply(*args)``.

    e.g.

        @traceable
        class MyOp(torch.autograd.Function):
            xxx

    Args:
        cls: The class to wrap. It must provide ``forward(ctx, *args)`` and
            ``apply(*args)``.

    Returns:
        A wrapper class that carries the same ``__name__``, ``__qualname__``,
        ``__module__`` and ``__doc__`` as ``cls``.
    """
    class _Function(object):
        @staticmethod
        def apply(*args):
            # Re-read on every call so that toggling JIT_TRACE at runtime
            # (e.g. with TraceMode) takes effect immediately.
            jit_trace = os.getenv('JIT_TRACE', 'False').lower() == 'true'
            if jit_trace:
                # Pass a fresh instance as ``ctx``. Any attributes forward()
                # stores on ctx (e.g. ``ctx.dim = dim``) then stay local to
                # this call and do not leak onto the shared wrapper class.
                return cls.forward(_Function(), *args)
            return cls.apply(*args)

        @staticmethod
        def save_for_backward(*args):
            # No-op stand-in so that ``ctx.save_for_backward(...)`` inside
            # forward() succeeds on the direct-call path. Nothing is saved.
            pass

    _Function.__name__ = cls.__name__
    _Function.__qualname__ = cls.__qualname__
    _Function.__module__ = cls.__module__
    _Function.__doc__ = cls.__doc__
    return _Function
class TraceMode():
    """Context manager that enables tracing mode for custom operators/Functions.

    While active, the ``JIT_TRACE`` environment variable is set to ``'True'``.
    Wrappers produced by ``traceable`` check this variable and call
    ``forward`` directly instead of ``apply``.

    On exit, the previous value of ``JIT_TRACE`` is restored, or the variable
    is removed if it was unset before entering. This means:

    * Nested use, including re-entering the same instance, is safe.
    * A pre-existing user setting is preserved.

    Exceptions raised inside the block are not suppressed.
    """

    def __init__(self):
        # Stack of the JIT_TRACE values seen by each __enter__. Each
        # __exit__ restores the matching one, which makes the instance
        # re-entrant.
        self._saved = []

    def __enter__(self):
        self._saved.append(os.environ.get('JIT_TRACE'))
        os.environ['JIT_TRACE'] = 'True'
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        previous = self._saved.pop()
        if previous is None:
            # The variable was unset before entering. pop() with a default
            # also tolerates the body having removed it already.
            os.environ.pop('JIT_TRACE', None)
        else:
            os.environ['JIT_TRACE'] = previous
        # Implicitly returns None, so any exception propagates.