Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] ✨ Allow continuation requests #630

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions kai/reactive_codeplanner/vfs/git_vfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,9 @@ def diff(self, other: "RepoContextSnapshot") -> tuple[int, str, str]:


class RepoContextManager:

last_snapshot_before_reset = None

def __init__(
self,
project_root: Path,
Expand Down Expand Up @@ -292,6 +295,7 @@ def reset(self, snapshot: Optional[RepoContextSnapshot] = None) -> None:
Resets the repository to the given snapshot. If no snapshot is provided,
reset the repo to the current snapshot.
"""
self.last_snapshot_before_reset = self.snapshot
if snapshot is not None:
self.snapshot = snapshot

Expand Down
147 changes: 86 additions & 61 deletions kai/rpc_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,72 +503,80 @@ def simple_chat_message(msg: str) -> None:
server.send_response(id=id, error=ERROR_NOT_INITIALIZED)
return

# Get a snapshot of the current state of the repo so we can reset it
# later
app.rcm.commit(
f"get_codeplan_agent_solution. id: {id}", run_reflection_agent=False
)
agent_solution_snapshot = app.rcm.snapshot
# If incidents are provided, we are on a new run
# If it is empty, this is a request for continuation
if params.incidents:
# Get a snapshot of the current state of the repo so we can reset it
# later
app.rcm.commit(
f"get_codeplan_agent_solution. id: {id}", run_reflection_agent=False
)
agent_solution_snapshot = app.rcm.snapshot

app.config = cast(KaiRpcApplicationConfig, app.config)
app.config = cast(KaiRpcApplicationConfig, app.config)

# Data for AnalyzerRuleViolation should probably take an ExtendedIncident
seed_tasks: list[Task] = []
# Data for AnalyzerRuleViolation should probably take an ExtendedIncident
seed_tasks: list[Task] = []

params.incidents.sort()
grouped_incidents_by_files = [
list(g) for _, g in groupby(params.incidents, key=attrgetter("uri"))
]
for incidents in grouped_incidents_by_files:

# group incidents by violation
grouped_violations = [
list(g) for _, g in groupby(incidents, key=attrgetter("violation_name"))
params.incidents.sort()
grouped_incidents_by_files = [
list(g) for _, g in groupby(params.incidents, key=attrgetter("uri"))
]
for violation_incidents in grouped_violations:

incident_base = violation_incidents[0]
uri_path = urlparse(incident_base.uri).path
if platform.system() == "Windows":
uri_path = uri_path.removeprefix("/")

class_to_use = AnalyzerRuleViolation
if "pom.xml" in incident_base.uri:
class_to_use = AnalyzerDependencyRuleViolation

validation_error = class_to_use(
file=str(Path(uri_path).absolute()),
violation=Violation(
id=incident_base.violation_name or "",
description=incident_base.violation_description or "",
category=incident_base.violation_category,
labels=incident_base.violation_labels,
),
ruleset=RuleSet(
name=incident_base.ruleset_name,
description=incident_base.ruleset_description or "",
),
line=0,
column=-1,
message="",
incidents=[],
)
validation_error.incidents = []
for i in violation_incidents:
if i.line_number < 0:
continue
validation_error.incidents.append(Incident(**i.model_dump()))

if validation_error.incidents:
app.log.log(
TRACE,
"seed_tasks adding to list: %s -- incident_messages: %s",
validation_error,
validation_error.incident_message,
for incidents in grouped_incidents_by_files:

# group incidents by violation
grouped_violations = [
list(g)
for _, g in groupby(incidents, key=attrgetter("violation_name"))
]
for violation_incidents in grouped_violations:

incident_base = violation_incidents[0]
uri_path = urlparse(incident_base.uri).path
if platform.system() == "Windows":
uri_path = uri_path.removeprefix("/")

class_to_use = AnalyzerRuleViolation
if "pom.xml" in incident_base.uri:
class_to_use = AnalyzerDependencyRuleViolation

validation_error = class_to_use(
file=str(Path(uri_path).absolute()),
violation=Violation(
id=incident_base.violation_name or "",
description=incident_base.violation_description or "",
category=incident_base.violation_category,
labels=incident_base.violation_labels,
),
ruleset=RuleSet(
name=incident_base.ruleset_name,
description=incident_base.ruleset_description or "",
),
line=0,
column=-1,
message="",
incidents=[],
)
seed_tasks.append(validation_error)
validation_error.incidents = []
for i in violation_incidents:
if i.line_number < 0:
continue
validation_error.incidents.append(Incident(**i.model_dump()))

if validation_error.incidents:
app.log.log(
TRACE,
"seed_tasks adding to list: %s -- incident_messages: %s",
validation_error,
validation_error.incident_message,
)
seed_tasks.append(validation_error)

app.task_manager.set_seed_tasks(*seed_tasks)
app.task_manager.set_seed_tasks(*seed_tasks)
else:
# reset to the last git state
if app.rcm.last_snapshot_before_reset is not None:
app.rcm.reset(app.rcm.last_snapshot_before_reset)

app.log.info(
f"Starting code plan loop with iterations: {params.max_iterations}, max depth: {params.max_depth}, and max priority: {params.max_priority}"
Expand All @@ -585,11 +593,13 @@ class OverallResult(TypedDict):
encountered_errors: list[str]
modified_files: list[str]
diff: str
remaining_issues: list[Task]

overall_result: OverallResult = {
"encountered_errors": [],
"modified_files": [],
"diff": "",
"remaining_issues": [],
}

# get the solved tasks set
Expand Down Expand Up @@ -635,7 +645,7 @@ class OverallResult(TypedDict):

app.log.debug(result)

if seed_tasks:
if app.task_manager.priority_queue.task_stacks.get(0):
# If we have seed tasks, we are fixing a set of issues,
# Lets only focus on this when showing the queue.
all_tasks = app.task_manager.priority_queue.all_tasks()
Expand Down Expand Up @@ -679,6 +689,21 @@ class OverallResult(TypedDict):

simple_chat_message("Running validators...")

relevant_task_priority = -1
stacks = app.task_manager.priority_queue.task_stacks.keys()
if params.max_priority is not None:
relevant_task_priority = params.max_priority
elif stacks:
relevant_task_priority = max(stacks)

overall_result["remaining_issues"] = [
task
for tasks in [
app.task_manager.priority_queue.task_stacks.get(i, [])
for i in range(relevant_task_priority + 1)
]
for task in tasks
]
# after we have completed all the tasks, we should show what has been accomplished for this particular solution
app.log.debug("QUEUE_STATE_END_OF_CODE_PLAN: SUCCESSFUL TASKS: START")
for task in app.task_manager.processed_tasks - initial_solved_tasks:
Expand Down
Loading