Skip to content

Commit 1209dfb

Browse files
committed
Simplified classification code
1 parent 12a9e48 commit 1209dfb

File tree

2 files changed

+6
-9
lines changed

2 files changed

+6
-9
lines changed

.github/workflows/python-package.yml

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
22
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
33

4-
name: Python package
4+
name: Python Unit Tests, Lint, and Type Checks
55

66
on:
77
push:
@@ -39,3 +39,5 @@ jobs:
3939
- name: Run all unit tests in the /tests directory
4040
run: |
4141
python -m unittest discover -s tests
42+
- name: Check Type Hints with Pyright
43+
uses: jakebailey/[email protected]

KNN/knn.py

+3-8
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# limitations under the License.
1616
import csv
1717
from typing import Protocol, Self
18+
from collections import Counter
1819
import numpy as np
1920

2021

@@ -44,19 +45,13 @@ def _read_csv(self, file_path: str, has_header: bool) -> None:
4445

4546
# Find the k nearest neighbors of a given data point based on the distance method
4647
def nearest(self, k: int, data_point: DP) -> list[DP]:
47-
return sorted(self.data_points, key=lambda other: data_point.distance(other))[:k]
48+
return sorted(self.data_points, key=data_point.distance)[:k]
4849

4950
# Classify a data point based on the k nearest neighbors
5051
# Choose the kind with the most neighbors and return it
5152
def classify(self, k: int, data_point: DP) -> str:
5253
neighbors = self.nearest(k, data_point)
53-
kinds = {}
54-
for neighbor in neighbors:
55-
if neighbor.kind in kinds:
56-
kinds[neighbor.kind] += 1
57-
else:
58-
kinds[neighbor.kind] = 1
59-
return max(kinds, key=kinds.get) # type: ignore
54+
return Counter(neighbor.kind for neighbor in neighbors).most_common(1)[0][0]
6055

6156
# Predict a property of a data point based on the k nearest neighbors
6257
# Find the average of that property from the neighbors and return it

0 commit comments

Comments
 (0)