由于很多时候需要hack forward函数,修改其中的逻辑,目前学习到的有两种修改实例的方法,之后还会有补充。
__get__
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
| class OriginalClass: def original_function(self): print("这是原始函数的执行内容")
def hijacked_function(self): print("这是被劫持替换后的函数执行内容")
if __name__ == "__main__": original_obj = OriginalClass() original_main_entry = original_obj.original_function original_obj.original_function = hijacked_function.__get__(original_obj) original_obj.original_function() another_obj = OriginalClass() another_obj.original_function()
|
types.MethodType
1 2 3 4 5 6 7 8 9 10 11 12 13
| import types
class MyClass: def original_method(self): print("执行原始方法")
def new_method(self): print("执行替换后的方法")
if __name__ == "__main__": obj = MyClass() obj.original_method = types.MethodType(new_method, obj) obj.original_method()
|
题外话:pytoch hack
pytorch支持导出、修改Module中间变量
1 2 3 4
| torch.Tensor.register_hook() torch.nn.Module.register_forward_hook() torch.nn.Module.register_backward_hook() torch.nn.Module.register_forward_pre_hook()
|