Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 22 additions & 6 deletions mellea/helpers/event_loop_helper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Helper for event loop management. Allows consistently running async generate requests in sync code."""

import asyncio
import os
import threading
from collections.abc import Coroutine
from typing import Any, TypeVar
Expand All @@ -13,19 +14,33 @@
class _EventLoopHandler:
"""A class that handles the event loop for Mellea code. Do not directly instantiate this. Use `_run_async_in_thread`."""

def __init__(self) -> None:
"""Instantiates an EventLoopHandler. Used to ensure consistency when calling async code from sync code in Mellea.

Do not instantiate this class. Rely on the exported `_run_async_in_thread` function.
"""
def _event_loop_setup(self):
"""Sets up the event loop and thread."""
# This code lives in a helper function since both __init__ and _reinit_if_forked
# will need to use it.
self._pid = os.getpid() # Store the pid in case users fork this process.
self._event_loop = asyncio.new_event_loop()
self._thread: threading.Thread = threading.Thread( # type: ignore[annotation-unchecked]
target=self._event_loop.run_forever,
daemon=True, # type: ignore
)
self._thread.start()

def __del__(self) -> None:
def __init__(self):
"""Instantiates an EventLoopHandler. Used to ensure consistency when calling async code from sync code in Mellea.

Do not instantiate this class. Rely on the exported `_run_async_in_thread` function.
"""
self._event_loop_setup()

def _reinit_if_forked(self) -> None:
"""Reinitialize the event loop and thread if we're in a forked child to prevent hanging on awaited tasks."""
if os.getpid() != self._pid:
# If the process has been forked, reset the event loop and thread.
# Don't cleanup the parent's objects.
self._event_loop_setup()

def __del__(self):
"""Delete the event loop handler."""
self._close_event_loop()

Expand Down Expand Up @@ -55,6 +70,7 @@ async def finalize_tasks() -> None:

def __call__(self, co: Coroutine[Any, Any, R]) -> R:
"""Runs the coroutine in the event loop."""
self._reinit_if_forked()
if self._event_loop == get_current_event_loop():
# If this gets called from the same event loop, launch in a separate thread to prevent blocking.
return _EventLoopHandler()(co)
Expand Down
32 changes: 32 additions & 0 deletions test/helpers/test_event_loop_helper.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import multiprocessing

import pytest

import mellea.helpers.event_loop_helper as elh
Expand Down Expand Up @@ -32,6 +34,36 @@ async def testing() -> int:
assert elh.__event_loop_handler is not None


def test_event_loop_handler_with_forking():
"""Importing mellea before fork must not crash the child process."""

ctx = multiprocessing.get_context("fork")

def child():
import mellea.helpers.event_loop_helper as elh
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need to reimport ?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably not but it ensures that the singleton patten works across forks.


async def hello():
return 42

result = elh._run_async_in_thread(hello())
assert result == 42

p = ctx.Process(target=child)

try:
p.start()
p.join(timeout=15)
assert p.exitcode == 0, (
f"Child process failed after fork (exit code: {p.exitcode if p.exitcode is not None else 'timed out'})"
)

finally:
# Make sure we always clean up the process.
if p.is_alive():
p.kill()
p.join(timeout=15)


if __name__ == "__main__":
import pytest

Expand Down
Loading