
bug: optimize_model() fails on HF's GPT2 with "RuntimeError: CUDA error: operation not permitted when stream is capturing" #336

CorentinJ opened this issue Nov 9, 2023 · 4 comments

@CorentinJ

Description

After calling optimize_model() on a GPT2Model instance from Hugging Face's transformers, the model's forward pass raises RuntimeError: CUDA error: operation not permitted when stream is capturing.

A very similar issue is #15002, but it was closed without a resolution.

Steps to reproduce

from transformers import GPT2Config, GPT2Model
import torch
from kernl.model_optimization import optimize_model

model = GPT2Model(GPT2Config()).eval().cuda()
optimize_model(model)

with torch.cuda.amp.autocast():
    print(model(torch.tensor([[0]], device="cuda")))

Expected Behavior

The model's output should be printed, as it would be without the optimize_model(model) call.

Actual Behavior

Stack trace

root@f8a6de0f637b:~/project# python _debug.py
[2023-11-09 15:36:32,385] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/fx/graph_module.py", line 302, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1510, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1519, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<eval_with_key>.60", line 37, in forward
    mask_value = full_1.to(device(type='cuda', index=0));  full_1 = None
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: CUDA error: operation not permitted when stream is capturing
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


Call using an FX-traced Module, line 37 of the traced Module's generated forward function:
    full_1 = torch.full([], -65504.0, dtype = torch.float16)
    mask_value = full_1.to(device(type='cuda', index=0));  full_1 = None

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    to_1 = attn_weights_1.to(torch.float16);  attn_weights_1 = None

    attn_weights_2 = torch.where(causal_mask, to_1, mask_value);  causal_mask = to_1 = mask_value = None

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /root/project/_debug.py:431 in <module>                                                     │
│                                                                                                  │
│   428 optimize_model(model)                                                                      │
│   429                                                                                            │
│   430 with torch.cuda.amp.autocast():                                                            │
│ ❱ 431 │   print(model(torch.tensor([[0]], device="cuda")))                                       │
│   432                                                                                            │
│   433                                                                                            │
│   434 quit()                                                                                     │
│                                                                                                  │
│ /usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1510 in _wrapped_call_impl    │
│                                                                                                  │
│   1507 │   │   if self._compiled_call_impl is not None:                                          │
│   1508 │   │   │   return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]        │
│   1509 │   │   else:                                                                             │
│ ❱ 1510 │   │   │   return self._call_impl(*args, **kwargs)                                       │
│   1511 │                                                                                         │
│   1512 │   def _call_impl(self, *args, **kwargs):                                                │
│   1513 │   │   forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.fo  │
│                                                                                                  │
│ /usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1519 in _call_impl            │
│                                                                                                  │
│   1516 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1517 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1518 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1519 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1520 │   │                                                                                     │
│   1521 │   │   try:                                                                              │
│   1522 │   │   │   result = None                                                                 │
│                                                                                                  │
│ /usr/local/lib/python3.11/dist-packages/torch/_dynamo/eval_frame.py:409 in _fn                   │
│                                                                                                  │
│    406 │   │   │   dynamic_ctx = enable_dynamic(self.dynamic, self.export)                       │
│    407 │   │   │   dynamic_ctx.__enter__()                                                       │
│    408 │   │   │   try:                                                                          │
│ ❱  409 │   │   │   │   return fn(*args, **kwargs)                                                │
│    410 │   │   │   finally:                                                                      │
│    411 │   │   │   │   set_eval_frame(prior)                                                     │
│    412 │   │   │   │   dynamic_ctx.__exit__(None, None, None)                                    │
│                                                                                                  │
│ /usr/local/lib/python3.11/dist-packages/torch/_dynamo/eval_frame.py:570 in catch_errors          │
│                                                                                                  │
│    567 │   │   │   │   │   return hijacked_callback(frame, cache_entry, hooks, frame_state)      │
│    568 │   │                                                                                     │
│    569 │   │   with compile_lock, _disable_current_modes():                                      │
│ ❱  570 │   │   │   return callback(frame, cache_entry, hooks, frame_state)                       │
│    571 │                                                                                         │
│    572 │   catch_errors._torchdynamo_orig_callable = callback  # type: ignore[attr-defined]      │
│    573 │   return catch_errors                                                                   │
│                                                                                                  │
│ /usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py:668 in _convert_frame     │
│                                                                                                  │
│   665 │   def _convert_frame(frame: types.FrameType, cache_entry, hooks: Hooks, frame_state):    │
│   666 │   │   counters["frames"]["total"] += 1                                                   │
│   667 │   │   try:                                                                               │
│ ❱ 668 │   │   │   result = inner_convert(frame, cache_entry, hooks, frame_state)                 │
│   669 │   │   │   counters["frames"]["ok"] += 1                                                  │
│   670 │   │   │   return result                                                                  │
│   671 │   │   except Exception as e:                                                             │
│                                                                                                  │
│ /usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py:377 in                    │
│ _convert_frame_assert                                                                            │
│                                                                                                  │
│   374 │   │   │   },                                                                             │
│   375 │   │   )                                                                                  │
│   376 │   │                                                                                      │
│ ❱ 377 │   │   return _compile(                                                                   │
│   378 │   │   │   frame.f_code,                                                                  │
│   379 │   │   │   frame.f_globals,                                                               │
│   380 │   │   │   frame.f_locals,                                                                │
│                                                                                                  │
│ /usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py:592 in _compile           │
│                                                                                                  │
│   589 │                                                                                          │
│   590 │   with compile_context(CompileContext(compile_id)):                                      │
│   591 │   │   try:                                                                               │
│ ❱ 592 │   │   │   guarded_code = compile_inner(code, one_graph, hooks, transform)                │
│   593 │   │   │   return guarded_code                                                            │
│   594 │   │   except (                                                                           │
│   595 │   │   │   Unsupported,                                                                   │
│                                                                                                  │
│ /usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py:243 in time_wrapper               │
│                                                                                                  │
│    240 │   │   │   │   compilation_time_metrics[key] = []                                        │
│    241 │   │   │   with torch.profiler.record_function(f"{key} (dynamo_timed)"):                 │
│    242 │   │   │   │   t0 = time.time()                                                          │
│ ❱  243 │   │   │   │   r = func(*args, **kwargs)                                                 │
│    244 │   │   │   │   time_spent = time.time() - t0                                             │
│    245 │   │   │   compilation_time_metrics[key].append(time_spent)                              │
│    246 │   │   │   if phase_name:                                                                │
│                                                                                                  │
│ /usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py:509 in compile_inner      │
│                                                                                                  │
│   506 │   │   for attempt in itertools.count():                                                  │
│   507 │   │   │   CompileContext.get().attempt = attempt                                         │
│   508 │   │   │   try:                                                                           │
│ ❱ 509 │   │   │   │   out_code = transform_code_object(code, transform)                          │
│   510 │   │   │   │   orig_code_map[out_code] = code                                             │
│   511 │   │   │   │   break                                                                      │
│   512 │   │   │   except exc.RestartAnalysis as e:                                               │
│                                                                                                  │
│ /usr/local/lib/python3.11/dist-packages/torch/_dynamo/bytecode_transformation.py:1033 in         │
│ transform_code_object                                                                            │
│                                                                                                  │
│   1030 │   instructions = cleaned_instructions(code, safe)                                       │
│   1031 │   propagate_line_nums(instructions)                                                     │
│   1032 │                                                                                         │
│ ❱ 1033 │   transformations(instructions, code_options)                                           │
│   1034 │   return clean_and_assemble_instructions(instructions, keys, code_options)[1]           │
│   1035                                                                                           │
│   1036                                                                                           │
│                                                                                                  │
│ /usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py:150 in _fn                │
│                                                                                                  │
│   147 │   │   torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result          │
│   148 │   │   cleanup = setup_compile_debug()                                                    │
│   149 │   │   try:                                                                               │
│ ❱ 150 │   │   │   return fn(*args, **kwargs)                                                     │
│   151 │   │   finally:                                                                           │
│   152 │   │   │   cleanup.close()                                                                │
│   153 │   │   │   torch._C._set_grad_enabled(prior_grad_mode)                                    │
│                                                                                                  │
│ /usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py:474 in transform          │
│                                                                                                  │
│   471 │   │                                                                                      │
│   472 │   │   try:                                                                               │
│   473 │   │   │   with tracing(tracer.output.tracing_context), tracer.set_current_tx():          │
│ ❱ 474 │   │   │   │   tracer.run()                                                               │
│   475 │   │   except exc.UnspecializeRestartAnalysis:                                            │
│   476 │   │   │   speculation_log.clear()                                                        │
│   477 │   │   │   raise                                                                          │
│                                                                                                  │
│ /usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py:2120 in run            │
│                                                                                                  │
│   2117 │   │   │   │   │   self._freevars_ids[name] = id(f_locals[name])                         │
│   2118 │                                                                                         │
│   2119 │   def run(self):                                                                        │
│ ❱ 2120 │   │   super().run()                                                                     │
│   2121 │                                                                                         │
│   2122 │   def match_nested_cell(self, name, cell):                                              │
│   2123 │   │   """Match a cell in this method to one in a function we are inlining"""            │
│                                                                                                  │
│ /usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py:815 in run             │
│                                                                                                  │
│    812 │   │   │   │   while (                                                                   │
│    813 │   │   │   │   │   self.instruction_pointer is not None                                  │
│    814 │   │   │   │   │   and not self.output.should_exit                                       │
│ ❱  815 │   │   │   │   │   and self.step()                                                       │
│    816 │   │   │   │   ):                                                                        │
│    817 │   │   │   │   │   pass                                                                  │
│    818 │   │   │   except BackendCompilerFailed:                                                 │
│                                                                                                  │
│ /usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py:778 in step            │
│                                                                                                  │
│    775 │   │   │   TracingContext.set_current_loc(                                               │
│    776 │   │   │   │   self.f_code.co_filename, self.lineno, self.f_code.co_name                 │
│    777 │   │   │   )                                                                             │
│ ❱  778 │   │   │   getattr(self, inst.opname)(inst)                                              │
│    779 │   │   │                                                                                 │
│    780 │   │   │   return inst.opname != "RETURN_VALUE"                                          │
│    781 │   │   except Unsupported:                                                               │
│                                                                                                  │
│ /usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py:2235 in RETURN_VALUE   │
│                                                                                                  │
│   2232 │   │   │   f"torchdynamo done tracing {self.f_code.co_name} (RETURN_VALUE)",             │
│   2233 │   │   )                                                                                 │
│   2234 │   │   log.debug("RETURN_VALUE triggered compile")                                       │
│ ❱ 2235 │   │   self.output.compile_subgraph(                                                     │
│   2236 │   │   │   self,                                                                         │
│   2237 │   │   │   reason=GraphCompileReason(                                                    │
│   2238 │   │   │   │   "return_value", [self.frame_summary()], graph_break=False                 │
│                                                                                                  │
│ /usr/local/lib/python3.11/dist-packages/torch/_dynamo/output_graph.py:906 in compile_subgraph    │
│                                                                                                  │
│    903 │   │   │   output = []                                                                   │
│    904 │   │   │   if count_calls(self.graph) != 0 or len(pass2.graph_outputs) != 0:             │
│    905 │   │   │   │   output.extend(                                                            │
│ ❱  906 │   │   │   │   │   self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)   │
│    907 │   │   │   │   )                                                                         │
│    908 │   │   │   │                                                                             │
│    909 │   │   │   │   if len(pass2.graph_outputs) != 0:                                         │
│                                                                                                  │
│ /usr/lib/python3.11/contextlib.py:81 in inner                                                    │
│                                                                                                  │
│    78 │   │   @wraps(func)                                                                       │
│    79 │   │   def inner(*args, **kwds):                                                          │
│    80 │   │   │   with self._recreate_cm():                                                      │
│ ❱  81 │   │   │   │   return func(*args, **kwds)                                                 │
│    82 │   │   return inner                                                                       │
│    83                                                                                            │
│    84                                                                                            │
│                                                                                                  │
│ /usr/local/lib/python3.11/dist-packages/torch/_dynamo/output_graph.py:1025 in                    │
│ compile_and_call_fx_graph                                                                        │
│                                                                                                  │
│   1022 │   │   )                                                                                 │
│   1023 │   │   self.call_cleanup_hooks()                                                         │
│   1024 │   │   with self.restore_global_state():                                                 │
│ ❱ 1025 │   │   │   compiled_fn = self.call_user_compiler(gm)                                     │
│   1026 │   │   compiled_fn = disable(compiled_fn)                                                │
│   1027 │   │                                                                                     │
│   1028 │   │   counters["stats"]["unique_graphs"] += 1                                           │
│                                                                                                  │
│ /usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py:243 in time_wrapper               │
│                                                                                                  │
│    240 │   │   │   │   compilation_time_metrics[key] = []                                        │
│    241 │   │   │   with torch.profiler.record_function(f"{key} (dynamo_timed)"):                 │
│    242 │   │   │   │   t0 = time.time()                                                          │
│ ❱  243 │   │   │   │   r = func(*args, **kwargs)                                                 │
│    244 │   │   │   │   time_spent = time.time() - t0                                             │
│    245 │   │   │   compilation_time_metrics[key].append(time_spent)                              │
│    246 │   │   │   if phase_name:                                                                │
│                                                                                                  │
│ /usr/local/lib/python3.11/dist-packages/torch/_dynamo/output_graph.py:1096 in call_user_compiler │
│                                                                                                  │
│   1093 │   │   │   # aborting execution.                                                         │
│   1094 │   │   │   raise e                                                                       │
│   1095 │   │   except Exception as e:                                                            │
│ ❱ 1096 │   │   │   raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(              │
│   1097 │   │   │   │   e.__traceback__                                                           │
│   1098 │   │   │   ) from None                                                                   │
│   1099                                                                                           │
│                                                                                                  │
│ /usr/local/lib/python3.11/dist-packages/torch/_dynamo/output_graph.py:1077 in call_user_compiler │
│                                                                                                  │
│   1074 │   │   │   compiler_fn = self.compiler_fn                                                │
│   1075 │   │   │   if config.verify_correctness:                                                 │
│   1076 │   │   │   │   compiler_fn = WrapperBackend(compiler_fn)                                 │
│ ❱ 1077 │   │   │   compiled_fn = compiler_fn(gm, self.example_inputs())                          │
│   1078 │   │   │   _step_logger()(logging.INFO, f"done compiler function {name}")                │
│   1079 │   │   │   assert callable(compiled_fn), "compiler_fn did not return callable"           │
│   1080 │   │   except exceptions_allowed_to_be_fallback as e:                                    │
│                                                                                                  │
│ /usr/local/lib/python3.11/dist-packages/torch/_dynamo/repro/after_dynamo.py:117 in debug_wrapper │
│                                                                                                  │
│   114 │   │   │   │   │   add_paths(exc)                                                         │
│   115 │   │   │   │   │   raise                                                                  │
│   116 │   │   else:                                                                              │
│ ❱ 117 │   │   │   compiled_gm = compiler_fn(gm, example_inputs)                                  │
│   118 │   │                                                                                      │
│   119 │   │   return compiled_gm                                                                 │
│   120                                                                                            │
│                                                                                                  │
│ /root/project/kernl/src/kernl/model_optimization.py:29 in _compiler                         │
│                                                                                                  │
│   26 # https://github.com/pytorch/torchdynamo/issues/1816                                        │
│   27 def _compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):                │
│   28 │   dynamo_backend_ofi(gm)                                                                  │
│ ❱ 29 │   return cuda_graphs_wrapper(gm, example_inputs)                                          │
│   30                                                                                             │
│   31                                                                                             │
│   32 def optimize_model(model: PreTrainedModel) -> None:                                         │
│                                                                                                  │
│ /root/project/kernl/src/kernl/optimizer/cuda_graph.py:116 in cuda_graphs_wrapper            │
│                                                                                                  │
│   113 │   if not any(isinstance(inp, FakeTensor) for inp in inputs):                             │
│   114 │   │   inputs = prepare_inputs(inputs=inputs, pools=static_inputs_pool)                   │
│   115 │   │   model(*inputs)  # additional warmup needed when input is mutated by some kernel    │
│ ❱ 116 │   │   f = cudagraphify_impl(                                                             │
│   117 │   │   │   model=lambda args: model(*args), inputs=inputs, static_input_idxs=tuple(rang   │
│   118 │   │   )                                                                                  │
│   119 │   │   return lambda *args: f(prepare_inputs(inputs=args, pools=static_inputs_pool))      │
│                                                                                                  │
│ /usr/local/lib/python3.11/dist-packages/torch/_inductor/compile_fx.py:818 in cudagraphify_impl   │
│                                                                                                  │
│    815 │                                                                                         │
│    816 │   # record                                                                              │
│    817 │   graph = torch.cuda.CUDAGraph()                                                        │
│ ❱  818 │   with torch.cuda.graph(graph, stream=stream, capture_error_mode="thread_local"):       │
│    819 │   │   static_outputs = model(list(static_inputs))                                       │
│    820 │   if not isinstance(static_outputs, (list, tuple)):                                     │
│    821 │   │   static_outputs = (static_outputs,)                                                │
│                                                                                                  │
│ /usr/local/lib/python3.11/dist-packages/torch/cuda/graphs.py:193 in __exit__                     │
│                                                                                                  │
│   190 │   │   )                                                                                  │
│   191 │                                                                                          │
│   192 │   def __exit__(self, exc_type, exc_value, traceback):                                    │
│ ❱ 193 │   │   self.cuda_graph.capture_end()                                                      │
│   194 │   │   self.stream_ctx.__exit__(exc_type, exc_value, traceback)                           │
│   195 │   │   # returning None should propagate exceptions from either capture_end or stream_c   │
│   196                                                                                            │
│                                                                                                  │
│ /usr/local/lib/python3.11/dist-packages/torch/cuda/graphs.py:84 in capture_end                   │
│                                                                                                  │
│    81 │   │   Use :class:`~torch.cuda.graph` or :func:`~torch.cuda.make_graphed_callables`,      │
│    82 │   │   which call ``capture_end`` internally.                                             │
│    83 │   │   """                                                                                │
│ ❱  84 │   │   super().capture_end()                                                              │
│    85 │                                                                                          │
│    86 │   def replay(self):                                                                      │
│    87 │   │   r"""                                                                               │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
BackendCompilerFailed: backend='_compiler' raised:
RuntimeError: CUDA error: operation failed due to a previous error during capture
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
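
The failing FX line points at the likely root cause: full_1 = torch.full([], -65504.0, dtype=torch.float16) allocates the mask constant on the CPU, and the following .to(device) is a host-to-device copy from pageable memory, which CUDA forbids while a stream is being captured for a CUDA graph. A minimal sketch of the same failure class, independent of kernl (assuming any CUDA-enabled PyTorch build), would be:

import torch

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    # CPU-resident constant, as in line 37 of the traced forward above
    full_1 = torch.full([], -65504.0, dtype=torch.float16)
    # An H2D copy from pageable host memory synchronizes the stream,
    # which is not permitted during capture
    mask_value = full_1.to("cuda")

If that is indeed the cause, a plausible fix in the traced graph would be to materialize the constant directly on the device, e.g. torch.full([], -65504.0, dtype=torch.float16, device="cuda"), which is an ordinary kernel launch and therefore capturable.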

Your environment

  • Operating system and version: Ubuntu 20.04.6 LTS
  • Python version: Python 3.11.6, torch==2.2.0.dev20231109+cu118, pytorch-triton==2.1.0+6e4932cda8
  • Python package manager: pip 23.3.1
  • Kernl commit: 91e2cd9

Self-service

  • I would be willing to help fix this bug myself.

Code of Conduct

  • I agree to follow this project's Code of Conduct
@wilson97

AFAIK, around PyTorch 2.1 there was a change to how PyTorch represents the computation graph; it now uses SymPy for symbolic shapes. I don't think kernl has been updated to accept that new computation-graph format. However, since you are using a common Hugging Face model, torch.compile, DeepSpeed, or BetterTransformer will probably work.

@CorentinJ
Author

I will try using the recommended Python and PyTorch versions and report back.

For the record:

  • torch.compile is only a slowdown for GPT2 in my tests
  • DeepSpeed works well but is buggy and poorly documented
  • BetterTransformer works but is not enough of a speedup compared to DeepSpeed, ORT, or TRT (see the sketch below)
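
For reference, the BetterTransformer path mentioned above is a one-line transform via Hugging Face Optimum; a minimal sketch, assuming the optimum package is installed:

from optimum.bettertransformer import BetterTransformer
from transformers import GPT2Config, GPT2Model

model = GPT2Model(GPT2Config()).eval().cuda()
# Replaces supported submodules with fused BetterTransformer equivalents
model = BetterTransformer.transform(model)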

@wilson97

That's a bit odd: torch.compile (with dynamic=True) gives me a 1.6x speedup (Google Colab, A100), certainly not a slowdown. What is your hardware?
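
For comparison, a sketch of that torch.compile setup (the suggested alternative, not kernl's API):

import torch
from transformers import GPT2Config, GPT2Model

model = GPT2Model(GPT2Config()).eval().cuda()
# dynamic=True compiles with symbolic shapes, avoiding a recompile
# for every new sequence length during autoregressive decoding
compiled_model = torch.compile(model, dynamic=True)

with torch.cuda.amp.autocast():
    print(compiled_model(torch.tensor([[0]], device="cuda")))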

@wilson97

@CorentinJ Also, perhaps unrelated: I noticed you work at Resemble AI. Are you trying to speed up tortoise-tts by making HF GPT2 faster for its autoregressive part? I'm working on a similar model to tortoise-tts and have it deployed: https://voicegen.org/. Maybe we can compare notes and help each other out. My email is wilson97@gmail.com if you're interested.
