Skip to content

Commit 7975e0a

Browse files
committed
✨ Allow continuation requests
- If no incidents are included in a solution request, treat it as a continuation of the prior request. We reset the git state to the last snapshot before return, and continue working on the queue - This does not work if multiple requests are hitting the server, we'd need to do additional work to support that Signed-off-by: Fabian von Feilitzsch <[email protected]>
1 parent 0abd4fc commit 7975e0a

File tree

2 files changed

+87
-61
lines changed

2 files changed

+87
-61
lines changed

kai/reactive_codeplanner/vfs/git_vfs.py

+4
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,9 @@ def diff(self, other: "RepoContextSnapshot") -> tuple[int, str, str]:
233233

234234

235235
class RepoContextManager:
236+
237+
last_snapshot_before_reset = None
238+
236239
def __init__(
237240
self,
238241
project_root: Path,
@@ -292,6 +295,7 @@ def reset(self, snapshot: Optional[RepoContextSnapshot] = None) -> None:
292295
Resets the repository to the given snapshot. If no snapshot is provided,
293296
reset the repo to the current snapshot.
294297
"""
298+
self.last_snapshot_before_reset = self.snapshot
295299
if snapshot is not None:
296300
self.snapshot = snapshot
297301

kai/rpc_server/server.py

+83-61
Original file line numberDiff line numberDiff line change
@@ -503,72 +503,80 @@ def simple_chat_message(msg: str) -> None:
503503
server.send_response(id=id, error=ERROR_NOT_INITIALIZED)
504504
return
505505

506-
# Get a snapshot of the current state of the repo so we can reset it
507-
# later
508-
app.rcm.commit(
509-
f"get_codeplan_agent_solution. id: {id}", run_reflection_agent=False
510-
)
511-
agent_solution_snapshot = app.rcm.snapshot
506+
# If incidents are provided, we are on a new run
507+
# If it is empty, this is a request for continuation
508+
if params.incidents:
509+
# Get a snapshot of the current state of the repo so we can reset it
510+
# later
511+
app.rcm.commit(
512+
f"get_codeplan_agent_solution. id: {id}", run_reflection_agent=False
513+
)
514+
agent_solution_snapshot = app.rcm.snapshot
512515

513-
app.config = cast(KaiRpcApplicationConfig, app.config)
516+
app.config = cast(KaiRpcApplicationConfig, app.config)
514517

515-
# Data for AnalyzerRuleViolation should probably take an ExtendedIncident
516-
seed_tasks: list[Task] = []
518+
# Data for AnalyzerRuleViolation should probably take an ExtendedIncident
519+
seed_tasks: list[Task] = []
517520

518-
params.incidents.sort()
519-
grouped_incidents_by_files = [
520-
list(g) for _, g in groupby(params.incidents, key=attrgetter("uri"))
521-
]
522-
for incidents in grouped_incidents_by_files:
523-
524-
# group incidents by violation
525-
grouped_violations = [
526-
list(g) for _, g in groupby(incidents, key=attrgetter("violation_name"))
521+
params.incidents.sort()
522+
grouped_incidents_by_files = [
523+
list(g) for _, g in groupby(params.incidents, key=attrgetter("uri"))
527524
]
528-
for violation_incidents in grouped_violations:
529-
530-
incident_base = violation_incidents[0]
531-
uri_path = urlparse(incident_base.uri).path
532-
if platform.system() == "Windows":
533-
uri_path = uri_path.removeprefix("/")
534-
535-
class_to_use = AnalyzerRuleViolation
536-
if "pom.xml" in incident_base.uri:
537-
class_to_use = AnalyzerDependencyRuleViolation
538-
539-
validation_error = class_to_use(
540-
file=str(Path(uri_path).absolute()),
541-
violation=Violation(
542-
id=incident_base.violation_name or "",
543-
description=incident_base.violation_description or "",
544-
category=incident_base.violation_category,
545-
labels=incident_base.violation_labels,
546-
),
547-
ruleset=RuleSet(
548-
name=incident_base.ruleset_name,
549-
description=incident_base.ruleset_description or "",
550-
),
551-
line=0,
552-
column=-1,
553-
message="",
554-
incidents=[],
555-
)
556-
validation_error.incidents = []
557-
for i in violation_incidents:
558-
if i.line_number < 0:
559-
continue
560-
validation_error.incidents.append(Incident(**i.model_dump()))
561-
562-
if validation_error.incidents:
563-
app.log.log(
564-
TRACE,
565-
"seed_tasks adding to list: %s -- incident_messages: %s",
566-
validation_error,
567-
validation_error.incident_message,
525+
for incidents in grouped_incidents_by_files:
526+
527+
# group incidents by violation
528+
grouped_violations = [
529+
list(g)
530+
for _, g in groupby(incidents, key=attrgetter("violation_name"))
531+
]
532+
for violation_incidents in grouped_violations:
533+
534+
incident_base = violation_incidents[0]
535+
uri_path = urlparse(incident_base.uri).path
536+
if platform.system() == "Windows":
537+
uri_path = uri_path.removeprefix("/")
538+
539+
class_to_use = AnalyzerRuleViolation
540+
if "pom.xml" in incident_base.uri:
541+
class_to_use = AnalyzerDependencyRuleViolation
542+
543+
validation_error = class_to_use(
544+
file=str(Path(uri_path).absolute()),
545+
violation=Violation(
546+
id=incident_base.violation_name or "",
547+
description=incident_base.violation_description or "",
548+
category=incident_base.violation_category,
549+
labels=incident_base.violation_labels,
550+
),
551+
ruleset=RuleSet(
552+
name=incident_base.ruleset_name,
553+
description=incident_base.ruleset_description or "",
554+
),
555+
line=0,
556+
column=-1,
557+
message="",
558+
incidents=[],
568559
)
569-
seed_tasks.append(validation_error)
560+
validation_error.incidents = []
561+
for i in violation_incidents:
562+
if i.line_number < 0:
563+
continue
564+
validation_error.incidents.append(Incident(**i.model_dump()))
565+
566+
if validation_error.incidents:
567+
app.log.log(
568+
TRACE,
569+
"seed_tasks adding to list: %s -- incident_messages: %s",
570+
validation_error,
571+
validation_error.incident_message,
572+
)
573+
seed_tasks.append(validation_error)
570574

571-
app.task_manager.set_seed_tasks(*seed_tasks)
575+
app.task_manager.set_seed_tasks(*seed_tasks)
576+
else:
577+
# reset to the last git state
578+
if app.rcm.last_snapshot_before_reset is not None:
579+
app.rcm.reset(app.rcm.last_snapshot_before_reset)
572580

573581
app.log.info(
574582
f"Starting code plan loop with iterations: {params.max_iterations}, max depth: {params.max_depth}, and max priority: {params.max_priority}"
@@ -585,11 +593,13 @@ class OverallResult(TypedDict):
585593
encountered_errors: list[str]
586594
modified_files: list[str]
587595
diff: str
596+
remaining_issues: list[Task]
588597

589598
overall_result: OverallResult = {
590599
"encountered_errors": [],
591600
"modified_files": [],
592601
"diff": "",
602+
"remaining_issues": [],
593603
}
594604

595605
# get the solved tasks set
@@ -635,7 +645,7 @@ class OverallResult(TypedDict):
635645

636646
app.log.debug(result)
637647

638-
if seed_tasks:
648+
if app.task_manager.priority_queue.task_stacks.get(0):
639649
# If we have seed tasks, we are fixing a set of issues,
640650
# Lets only focus on this when showing the queue.
641651
all_tasks = app.task_manager.priority_queue.all_tasks()
@@ -679,6 +689,18 @@ class OverallResult(TypedDict):
679689

680690
simple_chat_message("Running validators...")
681691

692+
relevant_task_priority = max(app.task_manager.priority_queue.task_stacks.keys())
693+
if params.max_priority is not None:
694+
relevant_task_priority = params.max_priority
695+
696+
overall_result["remaining_issues"] = [
697+
task
698+
for tasks in [
699+
app.task_manager.priority_queue.task_stacks.get(i, [])
700+
for i in range(relevant_task_priority)
701+
]
702+
for task in tasks
703+
]
682704
# after we have completed all the tasks, we should show what has been accomplished for this particular solution
683705
app.log.debug("QUEUE_STATE_END_OF_CODE_PLAN: SUCCESSFUL TASKS: START")
684706
for task in app.task_manager.processed_tasks - initial_solved_tasks:

0 commit comments

Comments
 (0)