Skip to content

Commit 83a0411

Browse files
authored
feat: joblib cache (#749)
* cache function * fix test * bin cache * fix test * fix test * fix test * cache for different source * cache for localenv * remove unnecessary log * reformat * remove unrelated modify
1 parent 6c4db40 commit 83a0411

File tree

7 files changed

+65
-8
lines changed

7 files changed

+65
-8
lines changed

rdagent/components/coder/data_science/feature/eval_tests/feature_test.txt

+13-1
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,22 @@ X_loaded = deepcopy(X)
2828
y_loaded = deepcopy(y)
2929
X_test_loaded = deepcopy(X_test)
3030

31+
import sys
32+
import reprlib
33+
from joblib.memory import MemorizedFunc
34+
35+
36+
def get_original_code(func):
37+
if isinstance(func, MemorizedFunc):
38+
return func.func.__code__
39+
return func.__code__
40+
41+
3142
def debug_info_print(func):
3243
def wrapper(*args, **kwargs):
44+
original_code = get_original_code(func)
3345
def local_trace(frame, event, arg):
34-
if event == "return" and frame.f_code == func.__code__:
46+
if event == "return" and frame.f_code == original_code:
3547
print("\n" + "="*20 + "Running feat_eng code, local variable values:" + "="*20)
3648
for k, v in frame.f_locals.items():
3749
printed = aRepr.repr(v)

rdagent/components/coder/data_science/feature/prompts.yaml

+6-1
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,15 @@ feature_coder:
3939
```python
4040
{{ data_loader_code }}
4141
```
42-
3. **Additional Guidance:**
42+
4. **Additional Guidance:**
4343
- If a previous attempt exists, improve upon it without repeating mistakes.
4444
- If errors indicate a missing file, find a way to download it or implement an alternative solution.
4545
- You should avoid using logging module to output information in your generated code, and instead use the print() function.
46+
5. You should use the following cache decorator to cache the results of the function:
47+
```python
48+
from joblib import Memory
49+
memory = Memory(location='/tmp/cache', verbose=0)
50+
@memory.cache
```
4651
4752
## Output Format
4853
{% if out_spec %}

rdagent/components/coder/data_science/model/eval_tests/model_test.txt

+14-1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,17 @@ print(f"test_ids length: {len(test_ids)}")
3535

3636
train_X, val_X, train_y, val_y = train_test_split(X, y, test_size=0.8, random_state=42)
3737

38+
39+
import sys
40+
import reprlib
41+
from joblib.memory import MemorizedFunc
42+
43+
44+
def get_original_code(func):
45+
if isinstance(func, MemorizedFunc):
46+
return func.func.__code__
47+
return func.__code__
48+
3849
print("train_X:", aRepr.repr(train_X))
3950
print("train_y:", aRepr.repr(train_y))
4051
print("val_X:", aRepr.repr(val_X))
@@ -46,10 +57,12 @@ print(f"val_X.shape: {val_X.shape}" if hasattr(val_X, 'shape') else f"val_X leng
4657
print(f"val_y.shape: {val_y.shape}" if hasattr(val_y, 'shape') else f"val_y length: {len(val_y)}")
4758

4859

60+
4961
def debug_info_print(func):
5062
def wrapper(*args, **kwargs):
63+
original_code = get_original_code(func)
5164
def local_trace(frame, event, arg):
52-
if event == "return" and frame.f_code == func.__code__:
65+
if event == "return" and frame.f_code == original_code:
5366
print("\n" + "="*20 + "Running model training code, local variable values:" + "="*20)
5467
for k, v in frame.f_locals.items():
5568
printed = aRepr.repr(v)

rdagent/components/coder/data_science/model/prompts.yaml

+5
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,11 @@ model_coder:
4040
{{ feature_code }}
4141
2. You should avoid using logging module to output information in your generated code, and instead use the print() function.
4242
3. If the model can both be implemented by PyTorch and Tensorflow, please use pytorch for broader compatibility.
43+
4. You should use the following cache decorator to cache the results of the function:
44+
```python
45+
from joblib import Memory
46+
memory = Memory(location='/tmp/cache', verbose=0)
47+
@memory.cache
```
4348
4449
## Output Format
4550
{% if out_spec %}

rdagent/components/coder/data_science/raw_data_loader/eval_tests/data_loader_test.txt

+11-1
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,22 @@ from load_data import load_data
99

1010
import sys
1111
import reprlib
12+
from joblib.memory import MemorizedFunc
13+
14+
15+
def get_original_code(func):
16+
if isinstance(func, MemorizedFunc):
17+
return func.func.__code__
18+
return func.__code__
19+
20+
1221
def debug_info_print(func):
1322
aRepr = reprlib.Repr()
1423
aRepr.maxother=300
1524
def wrapper(*args, **kwargs):
25+
original_code = get_original_code(func)
1626
def local_trace(frame, event, arg):
17-
if event == "return" and frame.f_code == func.__code__:
27+
if event == "return" and frame.f_code == original_code:
1828
print("\n" + "="*20 + "Running data_load code, local variable values:" + "="*20)
1929
for k, v in frame.f_locals.items():
2030
printed = aRepr.repr(v)

rdagent/components/coder/data_science/raw_data_loader/prompts.yaml

+5
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,11 @@ data_loader_coder:
334334
## Guidelines
335335
1. Ensure that the dataset is loaded strictly from `/kaggle/input/`, following the exact folder structure described in the **Data Folder Description**, and do not attempt to load data from the current directory (`./`).
336336
2. You should avoid using logging module to output information in your generated code, and instead use the print() function.
337+
3. You should use the following cache decorator to cache the results of the function:
338+
```python
339+
from joblib import Memory
340+
memory = Memory(location='/tmp/cache', verbose=0)
341+
@memory.cache
```
337342
338343
## Exploratory Data Analysis (EDA) part(Required):
339344
- Before returning the data, you should always add an EDA part describing the data to help the following steps understand the data better.

rdagent/utils/env.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
The motiviation of the utils is for environment management
2+
The motivation of the utils is for environment management
33
44
Tries to create uniform environment for the agent to run;
55
- All the code and data is expected included in one folder
@@ -210,8 +210,8 @@ def cached_run(
210210
target_folder = Path(RD_AGENT_SETTINGS.pickle_cache_folder_path_str) / f"utils.env.run"
211211
target_folder.mkdir(parents=True, exist_ok=True)
212212

213-
# we must add the information of data (beyound code) into the key.
214-
# Otherwise, all commands operating on data will become invalue (e.g. rm -r submission.csv)
213+
# we must add the information of data (beyond code) into the key.
214+
# Otherwise, all commands operating on data will become invalid (e.g. rm -r submission.csv)
215215
# So we recursively walk in the folder and add the sorted relative filename list as part of the key.
216216
data_key = []
217217
for path in Path(local_path).rglob("*"):
@@ -292,7 +292,7 @@ class LocalConf(EnvConf):
292292

293293
class LocalEnv(Env[ASpecificLocalConf]):
294294
"""
295-
Sometimes local environment may be more convinient for testing
295+
Sometimes local environment may be more convenient for testing
296296
"""
297297

298298
def prepare(self) -> None: ...
@@ -311,6 +311,9 @@ def _run_ret_code(
311311
if self.conf.extra_volumes is not None:
312312
for lp, rp in self.conf.extra_volumes.items():
313313
volumes[lp] = rp
314+
cache_path = "/tmp/sample" if "/sample/" in "".join(self.conf.extra_volumes.keys()) else "/tmp/full"
315+
Path(cache_path).mkdir(parents=True, exist_ok=True)
316+
volumes[cache_path] = "/tmp/cache"
314317
for lp, rp in running_extra_volume.items():
315318
volumes[lp] = rp
316319

@@ -605,9 +608,13 @@ def _run_ret_code(
605608
if local_path is not None:
606609
local_path = os.path.abspath(local_path)
607610
volumes[local_path] = {"bind": self.conf.mount_path, "mode": "rw"}
611+
608612
if self.conf.extra_volumes is not None:
609613
for lp, rp in self.conf.extra_volumes.items():
610614
volumes[lp] = {"bind": rp, "mode": self.conf.extra_volume_mode}
615+
cache_path = "/tmp/sample" if "/sample/" in "".join(self.conf.extra_volumes.keys()) else "/tmp/full"
616+
Path(cache_path).mkdir(parents=True, exist_ok=True)
617+
volumes[cache_path] = {"bind": "/tmp/cache", "mode": "rw"}
611618
for lp, rp in running_extra_volume.items():
612619
volumes[lp] = {"bind": rp, "mode": self.conf.extra_volume_mode}
613620

0 commit comments

Comments
 (0)