Skip to content

WIP: Structure tool and prompt-service error handling improvements (for SDK v0.71) #1277

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 57 additions & 67 deletions backend/prompt_studio/prompt_studio_core_v2/prompt_studio_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -952,25 +952,24 @@ def _fetch_response(
TSPKeys.EXECUTION_SOURCE: ExecutionSource.IDE.value,
}

responder = PromptTool(
tool=util,
prompt_host=settings.PROMPT_HOST,
prompt_port=settings.PROMPT_PORT,
)
include_metadata = {TSPKeys.INCLUDE_METADATA: True}
headers = {Common.X_REQUEST_ID: StateStore.get(Common.REQUEST_ID)}
answer = responder.answer_prompt(
payload=payload, params=include_metadata, headers=headers
)
if answer["status"] == "ERROR":
error_message = answer.get("error", "")
try:
responder = PromptTool(
tool=util,
prompt_host=settings.PROMPT_HOST,
prompt_port=settings.PROMPT_PORT,
request_id=StateStore.get(Common.REQUEST_ID),
)
params = {TSPKeys.INCLUDE_METADATA: True}
return responder.answer_prompt(payload=payload, params=params)
except SdkError as e:
msg = str(e)
if e.actual_err and hasattr(e.actual_err, "response"):
msg = e.actual_err.response.json().get("error", str(e))
raise AnswerFetchError(
"Error while fetching response for "
f"'{prompt.prompt_key}' with '{doc_name}'. {error_message}",
status_code=int(answer.get("status_code")),
f"'{prompt.prompt_key}' with '{doc_name}'. {msg}",
status_code=int(e.status_code or 500),
)
output_response = json.loads(answer["structure_output"])
return output_response

@staticmethod
def fetch_table_settings_if_enabled(
Expand Down Expand Up @@ -1101,22 +1100,22 @@ def dynamic_indexer(

util = PromptIdeBaseTool(log_level=LogLevel.INFO, org_id=org_id)

responder = PromptTool(
tool=util,
prompt_host=settings.PROMPT_HOST,
prompt_port=settings.PROMPT_PORT,
)
headers = {Common.X_REQUEST_ID: StateStore.get(Common.REQUEST_ID)}
response = responder.index(payload=payload, headers=headers)

status_code = response.get("status_code")
if status_code == 200:
doc_id = json.loads(response.get("structure_output")).get("doc_id")
else:
error_message = f"Failed to index '{filename}'. " + response.get(
"error", ""
try:
responder = PromptTool(
tool=util,
prompt_host=settings.PROMPT_HOST,
prompt_port=settings.PROMPT_PORT,
request_id=StateStore.get(Common.REQUEST_ID),
)
doc_id = responder.index(payload=payload)
except SdkError as e:
msg = str(e)
if e.actual_err and hasattr(e.actual_err, "response"):
msg = e.actual_err.response.json().get("error", str(e))
raise IndexingAPIError(
f"Failed to index '{filename}'. {msg}",
status_code=int(e.status_code or 500),
)
raise IndexingAPIError(error_message)

PromptStudioIndexHelper.handle_index_manager(
document_id=document_id,
Expand All @@ -1128,15 +1127,19 @@ def dynamic_indexer(
)
return {"status": IndexingStatus.COMPLETED_STATUS.value, "output": doc_id}
except (IndexingError, IndexingAPIError, SdkError) as e:
logger.error(f"Indexing failed : {e} ", stack_info=True, exc_info=True)
doc_name = os.path.split(file_path)[1]
msg = str(e)
if isinstance(e, SdkError) and hasattr(e.actual_err, "response"):
msg = e.actual_err.response.json().get("error", str(e))

msg = f"Error while indexing '{filename}'. {msg}"
logger.error(msg, stack_info=True, exc_info=True)
PromptStudioHelper._publish_log(
{"tool_id": tool_id, "run_id": run_id, "doc_name": doc_name},
{"tool_id": tool_id, "run_id": run_id, "doc_name": filename},
LogLevels.ERROR,
LogLevels.RUN,
f"Indexing failed : {e}",
msg,
)
raise IndexingAPIError(f"Error while indexing '{doc_name}'. {str(e)}") from e
raise IndexingAPIError(msg) from e

@staticmethod
def _fetch_single_pass_response(
Expand Down Expand Up @@ -1254,22 +1257,10 @@ def _fetch_single_pass_response(
tool=util,
prompt_host=settings.PROMPT_HOST,
prompt_port=settings.PROMPT_PORT,
request_id=StateStore.get(Common.REQUEST_ID),
)
include_metadata = {TSPKeys.INCLUDE_METADATA: True}
headers = {Common.X_REQUEST_ID: StateStore.get(Common.REQUEST_ID)}
answer = responder.single_pass_extraction(
payload=payload,
params=include_metadata,
headers=headers,
)
if answer["status"] == "ERROR":
error_message = answer.get("error", None)
logger.info(f"{str(answer)}")
raise AnswerFetchError(
f"Error while fetching response for prompt(s). {error_message}"
)
output_response = json.loads(answer["structure_output"])
return output_response
params = {TSPKeys.INCLUDE_METADATA: True}
return responder.single_pass_extraction(payload=payload, params=params)

@staticmethod
def get_tool_from_tool_id(tool_id: str) -> CustomTool | None:
Expand Down Expand Up @@ -1334,27 +1325,26 @@ def dynamic_extractor(

util = PromptIdeBaseTool(log_level=LogLevel.INFO, org_id=org_id)

responder = PromptTool(
tool=util,
prompt_host=settings.PROMPT_HOST,
prompt_port=settings.PROMPT_PORT,
)
headers = {Common.X_REQUEST_ID: StateStore.get(Common.REQUEST_ID)}
response = responder.extract(payload=payload, headers=headers)
status_code = response.get("status_code")
if status_code == 200:
response_data = response.get("structure_output")
structure_output = json.loads(response_data)
extracted_text = structure_output.get("extracted_text")
try:
responder = PromptTool(
tool=util,
prompt_host=settings.PROMPT_HOST,
prompt_port=settings.PROMPT_PORT,
request_id=StateStore.get(Common.REQUEST_ID),
)
extracted_text = responder.extract(payload=payload)
PromptStudioIndexHelper.mark_extraction_status(
document_id=document_id,
profile_manager=profile_manager,
doc_id=doc_id,
)
else:
error_message = f"Failed to extract '{filename}'. " + response.get(
"error", ""
except SdkError as e:
msg = str(e)
if e.actual_err and hasattr(e.actual_err, "response"):
msg = e.actual_err.response.json().get("error", str(e))
raise ExtractionAPIError(
f"Failed to extract '{filename}'. {msg}",
status_code=int(e.status_code or 500),
)
raise ExtractionAPIError(error_message)

return extracted_text
2 changes: 1 addition & 1 deletion tools/structure/src/config/properties.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"schemaVersion": "0.0.1",
"displayName": "Structure Tool",
"functionName": "structure_tool",
"toolVersion": "0.0.76",
"toolVersion": "0.0.77",
"description": "This is a template tool which can answer set of input prompts designed in the Prompt Studio",
"input": {
"description": "File that needs to be indexed and parsed for answers"
Expand Down
25 changes: 13 additions & 12 deletions tools/structure/src/helpers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import datetime
import json
import logging
from typing import Any

from constants import IndexingConstants as IKeys
Expand All @@ -8,6 +8,8 @@
from unstract.sdk.prompt import PromptTool
from unstract.sdk.tool.base import BaseTool

logger = logging.getLogger(__name__)


class StructureToolHelper:
@staticmethod
Expand All @@ -22,7 +24,6 @@ def dynamic_extraction(
execution_run_data_folder: str,
) -> str:
x2text = tool_settings[SettingsKeys.X2TEXT_ADAPTER]
tool.stream_log(f"Extracting text from {file_path} into {extract_file_path}")
payload = {
IKeys.X2TEXT_INSTANCE_ID: x2text,
IKeys.FILE_PATH: file_path,
Expand All @@ -36,18 +37,15 @@ def dynamic_extraction(
IKeys.EXECUTION_DATA_DIR: str(execution_run_data_folder),
}

tool.stream_log(f"Payload constructed : {payload}")
responder = PromptTool(
logger.info(f"Prompt service payload for text extraction:\n{payload}")

prompt_tool = PromptTool(
tool=tool,
prompt_host=tool.get_env_or_die(SettingsKeys.PROMPT_HOST),
prompt_port=tool.get_env_or_die(SettingsKeys.PROMPT_PORT),
request_id=run_id,
)
tool.stream_log(f"responder : {responder}")
response = responder.extract(payload=payload)
response_data = response.get("structure_output")
structure_output = json.loads(response_data)
extracted_text = structure_output.get("extracted_text")
return extracted_text
return prompt_tool.extract(payload=payload)

@staticmethod
def dynamic_indexing(
Expand Down Expand Up @@ -87,13 +85,16 @@ def dynamic_indexing(
IKeys.EXTRACTED_TEXT: extracted_text,
}

sensitive_keys = [IKeys.EXTRACTED_TEXT]
payload_to_log = {k: v for k, v in payload.items() if k not in sensitive_keys}
logger.info(f"Prompt service payload for indexing:\n{payload_to_log}")
responder = PromptTool(
tool=tool,
prompt_host=tool.get_env_or_die(SettingsKeys.PROMPT_HOST),
prompt_port=tool.get_env_or_die(SettingsKeys.PROMPT_PORT),
request_id=run_id,
)
doc_id = responder.index(payload=payload)
return doc_id
return responder.index(payload=payload)

@staticmethod
def elapsed_time(start_time) -> float:
Expand Down
Loading