Skip to content

fix: dialogue_count incorrect in chatflow when there's... #11175

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions api/core/app/apps/advanced_chat/app_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
from extensions.ext_database import db
from factories import file_factory
from models.account import Account
Expand All @@ -33,6 +34,8 @@


class AdvancedChatAppGenerator(MessageBasedAppGenerator):
_dialogue_count: int

def generate(
self,
app_model: App,
Expand Down Expand Up @@ -211,6 +214,9 @@ def _generate(
db.session.commit()
db.session.refresh(conversation)

# get conversation dialogue count
self._dialogue_count = get_thread_messages_length(conversation.id)

# init queue manager
queue_manager = MessageBasedAppQueueManager(
task_id=application_generate_entity.task_id,
Expand Down Expand Up @@ -281,6 +287,7 @@ def _generate_worker(
queue_manager=queue_manager,
conversation=conversation,
message=message,
dialogue_count=self._dialogue_count,
)

runner.run()
Expand Down Expand Up @@ -334,6 +341,7 @@ def _handle_advanced_chat_response(
message=message,
user=user,
stream=stream,
dialogue_count=self._dialogue_count,
)

try:
Expand Down
10 changes: 3 additions & 7 deletions api/core/app/apps/advanced_chat/app_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,14 @@ def __init__(
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
dialogue_count: int,
) -> None:
super().__init__(queue_manager)

self.application_generate_entity = application_generate_entity
self.conversation = conversation
self.message = message
self._dialogue_count = dialogue_count

def run(self) -> None:
app_config = self.application_generate_entity.app_config
Expand Down Expand Up @@ -122,19 +124,13 @@ def run(self) -> None:

session.commit()

# Increment dialogue count.
self.conversation.dialogue_count += 1

conversation_dialogue_count = self.conversation.dialogue_count
db.session.commit()

# Create a variable pool.
system_inputs = {
SystemVariableKey.QUERY: query,
SystemVariableKey.FILES: files,
SystemVariableKey.CONVERSATION_ID: self.conversation.id,
SystemVariableKey.USER_ID: user_id,
SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count,
SystemVariableKey.DIALOGUE_COUNT: self._dialogue_count,
SystemVariableKey.APP_ID: app_config.app_id,
SystemVariableKey.WORKFLOW_ID: app_config.workflow_id,
SystemVariableKey.WORKFLOW_RUN_ID: self.application_generate_entity.workflow_run_id,
Expand Down
4 changes: 3 additions & 1 deletion api/core/app/apps/advanced_chat/generate_task_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def __init__(
message: Message,
user: Union[Account, EndUser],
stream: bool,
dialogue_count: int,
) -> None:
"""
Initialize AdvancedChatAppGenerateTaskPipeline.
Expand All @@ -98,6 +99,7 @@ def __init__(
:param message: message
:param user: user
:param stream: stream
:param dialogue_count: dialogue count
"""
super().__init__(application_generate_entity, queue_manager, user, stream)

Expand All @@ -114,7 +116,7 @@ def __init__(
SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.CONVERSATION_ID: conversation.id,
SystemVariableKey.USER_ID: user_id,
SystemVariableKey.DIALOGUE_COUNT: conversation.dialogue_count,
SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
Expand Down
32 changes: 32 additions & 0 deletions api/core/prompt/utils/get_thread_messages_length.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from extensions.ext_database import db
from models.model import Message


def get_thread_messages_length(conversation_id: str) -> int:
"""
Get the number of thread messages based on the parent message id.
"""
# Fetch all messages related to the conversation
query = (
db.session.query(
Message.id,
Message.parent_message_id,
Message.answer,
)
.filter(
Message.conversation_id == conversation_id,
)
.order_by(Message.created_at.desc())
)

messages = query.all()

# Extract thread messages
thread_messages = extract_thread_messages(messages)

# Exclude the newly created message with an empty answer
if thread_messages and not thread_messages[0].answer:
thread_messages.pop(0)

return len(thread_messages)