-
Notifications
You must be signed in to change notification settings - Fork 30
Feat: solve TSP using branch and bound method #34
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 19 commits
746fd19
75dc99d
a602381
bb52f58
e4f5eb3
a49fa65
859cb9e
e27e7f4
7e16b9d
bf9622b
b790c1e
bf2020a
38454e5
c6d6d88
e65a171
a1a913b
a4020a2
2b0002d
b5b0800
5203f0f
54a1b42
84a64f5
5830804
92f303a
150137a
866b9d8
9d24a3c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -142,3 +142,6 @@ dmypy.json | |
|
||
# Cython debug symbols | ||
cython_debug/ | ||
|
||
# Pycharm IDE | ||
.idea/ |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .node import Node # noqa: F401 | ||
from .priority_queue import PriorityQueue # noqa: F401 | ||
from .solver import solve_tsp_branch_and_bound # noqa: F401 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
from __future__ import annotations | ||
|
||
from dataclasses import dataclass | ||
from typing import List, Tuple | ||
|
||
import numpy as np | ||
|
||
|
||
@dataclass | ||
class Node: | ||
""" | ||
Represents a node in the search tree for the Traveling Salesperson Problem. | ||
|
||
Attributes | ||
---------- | ||
level : int | ||
luanleonardo marked this conversation as resolved.
Show resolved
Hide resolved
|
||
The level of the node in the search tree. | ||
index : int | ||
The index of the current city in the path. | ||
path : List[int] | ||
The list of city indices visited so far. | ||
cost : int | ||
The total cost of the path up to this node. | ||
cost_matrix : numpy.ndarray | ||
The cost matrix representing the distances between cities. | ||
|
||
Methods | ||
------- | ||
compute_reduced_matrix(matrix: numpy.ndarray) -> Tuple[numpy.ndarray, int]: | ||
Compute the reduced matrix and the cost of reducing it. | ||
from_cost_matrix(cost_matrix: numpy.ndarray) -> Node: | ||
Create a Node object from a given cost matrix. | ||
from_parent(parent: Node, index: int) -> Node: | ||
Create a new Node object based on a parent node and a city index. | ||
""" | ||
|
||
level: int | ||
index: int | ||
path: List[int] | ||
cost: int | ||
cost_matrix: np.ndarray | ||
|
||
@staticmethod | ||
def compute_reduced_matrix(matrix: np.ndarray) -> Tuple[np.ndarray, int]: | ||
""" | ||
Compute the reduced matrix and the cost of reducing it. | ||
|
||
Parameters | ||
---------- | ||
matrix : numpy.ndarray | ||
The cost matrix to compute the reductions. | ||
|
||
Returns | ||
------- | ||
Tuple[numpy.ndarray, int] | ||
A tuple containing the reduced matrix and the total | ||
cost of reductions. | ||
""" | ||
inf = np.iinfo(matrix.dtype).max | ||
mask = matrix != inf | ||
reduced_matrix = np.copy(matrix) | ||
|
||
min_rows = np.min(reduced_matrix, axis=1, keepdims=True) | ||
min_rows[min_rows == inf] = 0 | ||
if np.any(min_rows != 0): | ||
reduced_matrix = np.where( | ||
mask, reduced_matrix - min_rows, reduced_matrix | ||
) | ||
|
||
min_cols = np.min(reduced_matrix, axis=0, keepdims=True) | ||
min_cols[min_cols == inf] = 0 | ||
if np.any(min_cols != 0): | ||
reduced_matrix = np.where( | ||
mask, reduced_matrix - min_cols, reduced_matrix | ||
) | ||
|
||
return reduced_matrix, min_rows.sum() + min_cols.sum() | ||
|
||
@classmethod | ||
def from_cost_matrix(cls, cost_matrix: np.ndarray) -> Node: | ||
""" | ||
Create a Node object from a given cost matrix. | ||
|
||
Parameters | ||
---------- | ||
cost_matrix : numpy.ndarray | ||
The cost matrix representing the distances between cities. | ||
|
||
Returns | ||
------- | ||
Node | ||
A new Node object initialized with the reduced cost matrix. | ||
""" | ||
_cost_matrix, _cost = cls.compute_reduced_matrix(matrix=cost_matrix) | ||
return cls( | ||
level=0, | ||
index=0, | ||
path=[0], | ||
cost=_cost, | ||
cost_matrix=_cost_matrix, | ||
) | ||
|
||
@classmethod | ||
def from_parent(cls, parent: Node, index: int) -> Node: | ||
""" | ||
Create a new Node object based on a parent node and a city index. | ||
|
||
Parameters | ||
---------- | ||
parent : Node | ||
The parent node. | ||
index : int | ||
The index of the new city to be added to the path. | ||
|
||
Returns | ||
------- | ||
Node | ||
A new Node object with the updated path and cost. | ||
""" | ||
matrix = np.copy(parent.cost_matrix) | ||
inf = np.iinfo(matrix.dtype).max | ||
matrix[parent.index, :] = inf | ||
matrix[:, index] = inf | ||
matrix[index][0] = inf | ||
_cost_matrix, _cost = cls.compute_reduced_matrix(matrix=matrix) | ||
return cls( | ||
level=parent.level + 1, | ||
index=index, | ||
path=parent.path[:] + [index], | ||
cost=( | ||
parent.cost + _cost + parent.cost_matrix[parent.index][index] | ||
), | ||
cost_matrix=_cost_matrix, | ||
) | ||
|
||
def __lt__(self: Node, other: Node): | ||
""" | ||
Compare two Node objects based on their costs. | ||
|
||
Parameters | ||
---------- | ||
other : Node | ||
The other Node object to compare with. | ||
|
||
Returns | ||
------- | ||
bool | ||
True if this Node's cost is less than the other Node's | ||
cost, False otherwise. | ||
""" | ||
return self.cost < other.cost |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
from dataclasses import dataclass, field | ||
from heapq import heappop, heappush | ||
from typing import List | ||
|
||
from python_tsp.exact.branch_and_bound import Node | ||
|
||
|
||
@dataclass | ||
class PriorityQueue: | ||
""" | ||
A priority queue implementation using a binary heap | ||
for efficient element retrieval. | ||
|
||
Attributes | ||
---------- | ||
_container : List[Node] | ||
The list that holds the elements in the priority queue. | ||
|
||
Methods | ||
------- | ||
empty() -> bool: | ||
Check if the priority queue is empty. | ||
push(item: Node) -> None: | ||
Push an item into the priority queue. | ||
pop() -> Node: | ||
Pop the item with the highest priority from the priority queue. | ||
""" | ||
|
||
_container: List[Node] = field(default_factory=list) | ||
|
||
@property | ||
def empty(self) -> bool: | ||
""" | ||
Check if the priority queue is empty. | ||
|
||
Returns | ||
------- | ||
bool | ||
True if the priority queue is empty, False otherwise. | ||
""" | ||
return not self._container | ||
|
||
def push(self, item: Node) -> None: | ||
""" | ||
Push an item into the priority queue. | ||
|
||
Parameters | ||
---------- | ||
item : Node | ||
The item to be pushed into the priority queue. | ||
|
||
Returns | ||
------- | ||
None | ||
""" | ||
heappush(self._container, item) | ||
|
||
def pop(self) -> Node: | ||
""" | ||
Pop the item with the highest priority from the priority queue. | ||
|
||
Returns | ||
------- | ||
Node | ||
The node with the highest priority. | ||
""" | ||
return heappop(self._container) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
from typing import List, Tuple | ||
|
||
import numpy as np | ||
|
||
from python_tsp.exact.branch_and_bound import Node, PriorityQueue | ||
|
||
|
||
def solve_tsp_branch_and_bound( | ||
distance_matrix: np.ndarray, | ||
) -> Tuple[List[int], float]: | ||
""" | ||
Solve the Traveling Salesperson Problem (TSP) using | ||
the Branch and Bound algorithm. | ||
|
||
Parameters | ||
---------- | ||
distance_matrix : numpy.ndarray | ||
The distance matrix representing the distances between cities. | ||
|
||
Returns | ||
------- | ||
Tuple[List[int], float] | ||
A tuple containing the optimal path (list of city indices) and its | ||
total cost. If the TSP cannot be solved, an empty path and cost of 0 | ||
will be returned. | ||
""" | ||
num_cities = len(distance_matrix) | ||
cost_matrix = np.copy(distance_matrix) | ||
inf = np.iinfo(cost_matrix.dtype).max | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. From what I can get from the documentation, this method only works for integers. It is mentioned in the README that the solvers in this library can handle matrices with floats, but I think with this line the function would fail if the distance matrix is not of integer type. Can your code work floating point numbers as well? If so, maybe replacing this If not, tell me and we can figure something out. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Totally right @fillipe-gsm, it only worked for integers, I've fixed this behavior and now it works for distance matrix with float numbers too. |
||
np.fill_diagonal(cost_matrix, inf) | ||
|
||
root = Node.from_cost_matrix(cost_matrix=cost_matrix) | ||
pq = PriorityQueue([root]) | ||
|
||
while not pq.empty: | ||
min_node = pq.pop() | ||
|
||
if min_node.level == num_cities - 1: | ||
return min_node.path, min_node.cost | ||
|
||
for index in range(num_cities): | ||
is_live_node = min_node.cost_matrix[min_node.index][index] != inf | ||
if is_live_node: | ||
live_node = Node.from_parent(parent=min_node, index=index) | ||
pq.push(live_node) | ||
|
||
return [], 0 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This seems to be a behavior unique to this solver. All of the other solvers will return a valid permutation anyway, even it its cost is extremely large. I think it's okay in this case, but instead of returning 0 as objective value I think one of the two would be better:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @fillipe-gsm I prefer to return
|
Uh oh!
There was an error while loading. Please reload this page.