Skip to content

Commit fdafbc6

Browse files
authored
Merge pull request #15 from commaai/robust-model
Add more robust model
2 parents 4dad510 + aaf974f commit fdafbc6

File tree

9 files changed

+44
-33
lines changed

9 files changed

+44
-33
lines changed

.github/workflows/main.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,12 @@ jobs:
3131
3232
- name: Run Simple controller rollout
3333
run: |
34-
python tinyphysics.py --model_path ./models/tinyphysics.onnx --data_path ./data/00000.csv --controller simple
34+
python tinyphysics.py --model_path ./models/tinyphysics.onnx --data_path ./data/00000.csv --controller pid
3535
3636
- name: Run batch rollouts
3737
run: |
38-
python tinyphysics.py --model_path ./models/tinyphysics.onnx --data_path ./data --num_segs 20 --controller simple
38+
python tinyphysics.py --model_path ./models/tinyphysics.onnx --data_path ./data --num_segs 20 --controller pid
3939
4040
- name: Run report
4141
run: |
42-
python eval.py --model_path ./models/tinyphysics.onnx --data_path ./data --num_segs 50 --test_controller open --baseline_controller simple
42+
python eval.py --model_path ./models/tinyphysics.onnx --data_path ./data --num_segs 50 --test_controller zero --baseline_controller pid

README.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,16 @@ bash ./download_dataset.sh
1616
pip install -r requirements.txt
1717
1818
# test this works
19-
python tinyphysics.py --model_path ./models/tinyphysics.onnx --data_path ./data/00000.csv --debug --controller simple
19+
python tinyphysics.py --model_path ./models/tinyphysics.onnx --data_path ./data/00000.csv --debug --controller pid
2020
```
2121

2222
There are some other scripts to help you get aggregate metrics:
2323
```
2424
# batch Metrics of a controller on lots of routes
25-
python tinyphysics.py --model_path ./models/tinyphysics.onnx --data_path ./data --num_segs 100 --controller simple
25+
python tinyphysics.py --model_path ./models/tinyphysics.onnx --data_path ./data --num_segs 100 --controller pid
2626
2727
# generate a report comparing two controllers
28-
python eval.py --model_path ./models/tinyphysics.onnx --data_path ./data --num_segs 100 --test_controller simple --baseline_controller open
28+
python eval.py --model_path ./models/tinyphysics.onnx --data_path ./data --num_segs 100 --test_controller pid --baseline_controller zero
2929
3030
```
3131
You can also use the notebook at [`experiment.ipynb`](https://github.com/commaai/controls_challenge/blob/master/experiment.ipynb) for exploration.
@@ -44,13 +44,13 @@ Each rollout will result in 2 costs:
4444

4545
- `jerk_cost`: $\dfrac{\Sigma((actual\\_lat\\_accel\_t - actual\\_lat\\_accel\_{t-1}) / \Delta t)^2}{steps - 1} * 100$
4646

47-
It is important to minimize both costs. `total_cost`: $(lataccel\\_cost *5) + jerk\\_cost$
47+
It is important to minimize both costs. `total_cost`: $(lataccel\\_cost * 50) + jerk\\_cost$
4848

4949
## Submission
5050
Run the following command, and submit `report.html` and your code to [this form](https://forms.gle/US88Hg7UR6bBuW3BA).
5151

5252
```
53-
python eval.py --model_path ./models/tinyphysics.onnx --data_path ./data --num_segs 5000 --test_controller <insert your controller name> --baseline_controller simple
53+
python eval.py --model_path ./models/tinyphysics.onnx --data_path ./data --num_segs 5000 --test_controller <insert your controller name> --baseline_controller pid
5454
```
5555

5656
## Work at comma

controllers/pid.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from . import BaseController
2+
import numpy as np
3+
4+
class Controller(BaseController):
5+
"""
6+
A simple PID controller
7+
"""
8+
def __init__(self,):
9+
self.p = 0.3
10+
self.i = 0.05
11+
self.d = -0.1
12+
self.error_integral = 0
13+
self.prev_error = 0
14+
15+
def update(self, target_lataccel, current_lataccel, state, future_plan):
16+
error = (target_lataccel - current_lataccel)
17+
self.error_integral += error
18+
error_diff = error - self.prev_error
19+
self.prev_error = error
20+
return self.p * error + self.i * self.error_integral + self.d * error_diff

controllers/simple.py

Lines changed: 0 additions & 9 deletions
This file was deleted.

controllers/open.py renamed to controllers/zero.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
class Controller(BaseController):
55
"""
6-
An open-loop controller
6+
A controller that always outputs zero
77
"""
88
def update(self, target_lataccel, current_lataccel, state, future_plan):
9-
return target_lataccel
9+
return 0.0

eval.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ def create_report(test, baseline, sample_rollouts, costs):
7070
parser.add_argument("--model_path", type=str, required=True)
7171
parser.add_argument("--data_path", type=str, required=True)
7272
parser.add_argument("--num_segs", type=int, default=100)
73-
parser.add_argument("--test_controller", default='simple', choices=available_controllers)
74-
parser.add_argument("--baseline_controller", default='simple', choices=available_controllers)
73+
parser.add_argument("--test_controller", default='pid', choices=available_controllers)
74+
parser.add_argument("--baseline_controller", default='pid', choices=available_controllers)
7575
args = parser.parse_args()
7676

7777
data_path = Path(args.data_path)
@@ -99,7 +99,7 @@ def create_report(test, baseline, sample_rollouts, costs):
9999
for controller_cat, controller_type in [('baseline', args.baseline_controller), ('test', args.test_controller)]:
100100
print(f"Running batch rollouts => {controller_cat} controller: {controller_type}")
101101
rollout_partial = partial(run_rollout, controller_type=controller_type, model_path=args.model_path, debug=False)
102-
results = process_map(rollout_partial, files[SAMPLE_ROLLOUTS:], max_workers=16)
102+
results = process_map(rollout_partial, files[SAMPLE_ROLLOUTS:], max_workers=16, chunksize=10)
103103
costs += [{'controller': controller_cat, **result[0]} for result in results]
104104

105105
create_report(args.test_controller, args.baseline_controller, sample_rollouts, costs)

experiment.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
"outputs": [],
88
"source": [
99
"from tinyphysics import TinyPhysicsModel, TinyPhysicsSimulator, CONTROL_START_IDX\n",
10-
"from controllers import simple\n",
10+
"from controllers import pid\n",
1111
"from matplotlib import pyplot as plt\n",
1212
"import seaborn as sns\n",
1313
"\n",
@@ -38,7 +38,7 @@
3838
"outputs": [],
3939
"source": [
4040
"model = TinyPhysicsModel(\"./models/tinyphysics.onnx\", debug=True)\n",
41-
"controller = simple.Controller()"
41+
"controller = pid.Controller()"
4242
]
4343
},
4444
{

models/tinyphysics.onnx

-71.2 KB
Binary file not shown.

tinyphysics.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,14 @@
2222
ACC_G = 9.81
2323
FPS = 10
2424
CONTROL_START_IDX = 100
25-
COST_END_IDX = 550
25+
COST_END_IDX = 500
2626
CONTEXT_LENGTH = 20
2727
VOCAB_SIZE = 1024
2828
LATACCEL_RANGE = [-5, 5]
2929
STEER_RANGE = [-2, 2]
3030
MAX_ACC_DELTA = 0.5
3131
DEL_T = 0.1
32-
LAT_ACCEL_COST_MULTIPLIER = 5.0
32+
LAT_ACCEL_COST_MULTIPLIER = 50.0
3333

3434
FUTURE_PLAN_STEPS = FPS * 5 # 5 secs
3535

@@ -86,7 +86,7 @@ def get_current_lataccel(self, sim_states: List[State], actions: List[float], pa
8686
'states': np.expand_dims(states, axis=0).astype(np.float32),
8787
'tokens': np.expand_dims(tokenized_actions, axis=0).astype(np.int64)
8888
}
89-
return self.tokenizer.decode(self.predict(input_data, temperature=1.))
89+
return self.tokenizer.decode(self.predict(input_data, temperature=0.8))
9090

9191

9292
class TinyPhysicsSimulator:
@@ -148,10 +148,10 @@ def get_state_target_futureplan(self, step_idx: int) -> Tuple[State, float]:
148148
State(roll_lataccel=state['roll_lataccel'], v_ego=state['v_ego'], a_ego=state['a_ego']),
149149
state['target_lataccel'],
150150
FuturePlan(
151-
lataccel=self.data['target_lataccel'].values[step_idx + 1 :step_idx + FUTURE_PLAN_STEPS].tolist(),
152-
roll_lataccel=self.data['roll_lataccel'].values[step_idx + 1 :step_idx + FUTURE_PLAN_STEPS].tolist(),
153-
v_ego=self.data['v_ego'].values[step_idx + 1 :step_idx + FUTURE_PLAN_STEPS].tolist(),
154-
a_ego=self.data['a_ego'].values[step_idx + 1 :step_idx + FUTURE_PLAN_STEPS].tolist()
151+
lataccel=self.data['target_lataccel'].values[step_idx + 1:step_idx + FUTURE_PLAN_STEPS].tolist(),
152+
roll_lataccel=self.data['roll_lataccel'].values[step_idx + 1:step_idx + FUTURE_PLAN_STEPS].tolist(),
153+
v_ego=self.data['v_ego'].values[step_idx + 1:step_idx + FUTURE_PLAN_STEPS].tolist(),
154+
a_ego=self.data['a_ego'].values[step_idx + 1:step_idx + FUTURE_PLAN_STEPS].tolist()
155155
)
156156
)
157157

@@ -222,7 +222,7 @@ def run_rollout(data_path, controller_type, model_path, debug=False):
222222
parser.add_argument("--data_path", type=str, required=True)
223223
parser.add_argument("--num_segs", type=int, default=100)
224224
parser.add_argument("--debug", action='store_true')
225-
parser.add_argument("--controller", default='simple', choices=available_controllers)
225+
parser.add_argument("--controller", default='pid', choices=available_controllers)
226226
args = parser.parse_args()
227227

228228
data_path = Path(args.data_path)
@@ -232,7 +232,7 @@ def run_rollout(data_path, controller_type, model_path, debug=False):
232232
elif data_path.is_dir():
233233
run_rollout_partial = partial(run_rollout, controller_type=args.controller, model_path=args.model_path, debug=False)
234234
files = sorted(data_path.iterdir())[:args.num_segs]
235-
results = process_map(run_rollout_partial, files, max_workers=16)
235+
results = process_map(run_rollout_partial, files, max_workers=16, chunksize=10)
236236
costs = [result[0] for result in results]
237237
costs_df = pd.DataFrame(costs)
238238
print(f"\nAverage lataccel_cost: {np.mean(costs_df['lataccel_cost']):>6.4}, average jerk_cost: {np.mean(costs_df['jerk_cost']):>6.4}, average total_cost: {np.mean(costs_df['total_cost']):>6.4}")

0 commit comments

Comments
 (0)