Skip to content

Commit df497de

Browse files
authored
Fix for circular import on ConversationValidator (All-Hands-AI#7583)
1 parent f12bf98 commit df497de

File tree

3 files changed

+13
-7
lines changed

3 files changed

+13
-7
lines changed

openhands/server/listen_socket.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
sio,
2424
)
2525
from openhands.storage.conversation.conversation_validator import (
26-
ConversationValidatorImpl,
26+
create_conversation_validator,
2727
)
2828

2929

@@ -38,7 +38,7 @@ async def connect(connection_id: str, environ):
3838
raise ConnectionRefusedError('No conversation_id in query params')
3939

4040
cookies_str = environ.get('HTTP_COOKIE', '')
41-
conversation_validator = ConversationValidatorImpl()
41+
conversation_validator = create_conversation_validator()
4242
user_id, github_user_id = await conversation_validator.validate(
4343
conversation_id, cookies_str
4444
)

openhands/storage/conversation/conversation_validator.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,12 @@ async def validate(self, conversation_id: str, cookies_str: str):
1010
return None, None
1111

1212

13-
conversation_validator_cls = os.environ.get(
14-
'OPENHANDS_CONVERSATION_VALIDATOR_CLS',
15-
'openhands.storage.conversation.conversation_validator.ConversationValidator',
16-
)
17-
ConversationValidatorImpl = get_impl(ConversationValidator, conversation_validator_cls)
13+
def create_conversation_validator():
14+
conversation_validator_cls = os.environ.get(
15+
'OPENHANDS_CONVERSATION_VALIDATOR_CLS',
16+
'openhands.storage.conversation.conversation_validator.ConversationValidator',
17+
)
18+
ConversationValidatorImpl = get_impl(
19+
ConversationValidator, conversation_validator_cls
20+
)
21+
return ConversationValidatorImpl()

openhands/utils/import_utils.py

+2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import importlib
2+
from functools import lru_cache
23
from typing import Type, TypeVar
34

45
T = TypeVar('T')
@@ -13,6 +14,7 @@ def import_from(qual_name: str):
1314
return result
1415

1516

17+
@lru_cache()
1618
def get_impl(cls: Type[T], impl_name: str | None) -> Type[T]:
1719
"""Import a named implementation of the specified class"""
1820
if impl_name is None:

0 commit comments

Comments
 (0)