
Commit 7428f5f

If desired, training can be stopped on a specific step without impacting the LR curve. (#739)

See the new option `--early-stop-on-step`, which may be used in JET tests to stop training on a specific step without the change to the LR curve that lowering `--max-steps` would cause.

---------

Signed-off-by: John St John <[email protected]>
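As a sketch of the intended usage (hedged: `parse_args` may require additional flags such as data and model options, which are omitted here for brevity), the LR schedule stays shaped by `--max-steps` while `--early-stop-on-step` only truncates the run:

from bionemo.evo2.run.train import parse_args, train

# Hypothetical minimal invocation: LR warmup/decay is computed against
# --max-steps, while --early-stop-on-step merely halts the run early.
args = parse_args(
    [
        "--mock-data",
        "--max-steps", "500000",         # full schedule length; shapes the LR curve
        "--early-stop-on-step", "100",   # stop after 100 optimizer steps; LR curve unchanged
    ]
)
train(args)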
1 parent 32e402a commit 7428f5f

File tree

3 files changed: +96 −5 lines changed

sub-packages/bionemo-evo2/src/bionemo/evo2/run/train.py

Lines changed: 20 additions & 2 deletions
@@ -48,6 +48,7 @@
 from nemo.utils.exp_manager import TimingCallback
 
 from bionemo.llm.utils.datamodule_utils import infer_global_batch_size
+from bionemo.testing.testing_callbacks import SignalAfterGivenStepCallback
 
 
 torch._dynamo.config.suppress_errors = True
@@ -119,7 +120,18 @@ def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace:
     parser.add_argument(
         "--grad-acc-batches", type=int, default=1, help="Number of batches to accumulate gradients over."
     )
-    parser.add_argument("--max-steps", type=int, help="Number of training optimizer update steps.")
+    parser.add_argument(
+        "--max-steps",
+        type=int,
+        help="Number of training optimizer update steps. This controls the total number of steps as well as the "
+        "shape of the learning rate curve.",
+        default=500000,
+    )
+    parser.add_argument(
+        "--early-stop-on-step",
+        type=int,
+        help="Stop training on this step, if set. This may be useful for testing or debugging purposes.",
+    )
     parser.add_argument(
         "--val-check-interval", type=int, help="Number of steps between validation measurements and model checkpoints."
     )
@@ -468,7 +480,13 @@ def train(args: argparse.Namespace):
         save_context_on_train_end=True,
     )
     callbacks.append(checkpoint_callback)
-
+    if args.early_stop_on_step:
+        # Ask the trainer to stop by setting should_stop to True rather than emitting a kill signal.
+        callbacks.append(
+            SignalAfterGivenStepCallback(
+                stop_step=args.early_stop_on_step, stop_before_step=True, use_trainer_should_stop=True
+            )
+        )
     if args.enable_preemption:
         callbacks.append(nl_callbacks.PreemptionCallback())
     if args.debug_ddp_parity_freq > 0:
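For readers unfamiliar with NeMo's Megatron callback hooks, the wiring above behaves like the following Lightning-style sketch; the class and hook below are illustrative stand-ins, not part of this commit:

from lightning.pytorch.callbacks import Callback


class EarlyStopOnStep(Callback):
    """Illustrative stand-in for SignalAfterGivenStepCallback(use_trainer_should_stop=True)."""

    def __init__(self, stop_step: int):
        # Mirrors stop_before_step=True: global_step is zero-indexed, so stopping
        # once it reaches stop_step - 1 runs exactly stop_step optimizer steps.
        self.stop_step = stop_step - 1

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        if trainer.global_step >= self.stop_step:
            # Cooperative stop: no kill signal, so checkpointing and teardown run
            # normally and the LR schedule is never recomputed.
            trainer.should_stop = True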

sub-packages/bionemo-evo2/tests/bionemo/evo2/run/test_train.py

Lines changed: 56 additions & 0 deletions
@@ -16,9 +16,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import io
 import os
+import re
+import shlex
 import subprocess
 import sys
+from contextlib import redirect_stderr, redirect_stdout
 
 import pytest
 import torch
@@ -73,6 +77,58 @@ def test_train_evo2_runs(tmp_path, num_steps=5):
     assert result.returncode == 0, "train_evo2 command failed."
 
 
+@pytest.mark.timeout(256)  # Optional: fail if the test takes too long.
+@pytest.mark.slow
+def test_train_evo2_stops(tmp_path, num_steps=500000, early_stop_steps=3):
+    """
+    This test runs `train_evo2` with mock data in a temporary directory.
+    It uses the temporary directory provided by pytest as the working directory.
+    The arguments are parsed and `train` is invoked in-process; we assert that training stops at the early-stop step.
+    """
+    open_port = find_free_network_port()
+    # a local copy of the environment
+    env = dict(**os.environ)
+    env["MASTER_PORT"] = str(open_port)
+
+    # Build the command string, then parse it into arguments for a direct
+    # in-process call to `train` (the `train_evo2` binary itself is not executed).
+    command = (
+        f"train_evo2 --mock-data --experiment-dir {tmp_path}/test_train "
+        "--model-size 1b_nv --num-layers 4 --hybrid-override-pattern SDH* "
+        "--no-activation-checkpointing --add-bias-output "
+        f"--max-steps {num_steps} --early-stop-on-step {early_stop_steps} --warmup-steps 1 --no-wandb "
+        "--seq-length 128 --hidden-dropout 0.1 --attention-dropout 0.1 "
+    )
+    command_parts_no_program = shlex.split(command)[1:]
+    args = parse_args(args=command_parts_no_program)
+    with distributed_model_parallel_state():
+        # Capture stdout/stderr during train function execution
+        stdout_buffer = io.StringIO()
+        stderr_buffer = io.StringIO()
+        with redirect_stdout(stdout_buffer), redirect_stderr(stderr_buffer):
+            train(args=args)
+        # Get the captured output
+        train_stdout = stdout_buffer.getvalue()
+        train_stderr = stderr_buffer.getvalue()
+        # Print the captured output for debugging
+        print("TRAIN FUNCTION STDOUT:")
+        print(train_stdout)
+        print("TRAIN FUNCTION STDERR:")
+        print(train_stderr)
+
+    # Assert that training logged a loss and stopped at the expected step.
+    assert "reduced_train_loss:" in train_stdout
+    pattern = r"\| global_step: (\d+) \|"
+
+    def extract_global_steps(log_string):
+        matches = re.findall(pattern, log_string)
+        return [int(step) for step in matches]
+
+    global_step_ints = extract_global_steps(train_stdout)
+    assert global_step_ints[-1] == early_stop_steps - 1
+    assert len(global_step_ints) == early_stop_steps
+
+
 @pytest.mark.slow
 @pytest.mark.parametrize("model_size", ["7b_nv", "7b_arc_longcontext"])
 def test_train_single_gpu(tmp_path, model_size: str):
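The step assertions above rely on zero-indexed `global_step` log lines. A standalone sanity check of that regex (the log-line format is assumed from the test itself, not taken from NeMo's logger):

import re

pattern = r"\| global_step: (\d+) \|"
# Fabricated log lines in the format the test expects.
fake_log = "\n".join(f"... | global_step: {i} | reduced_train_loss: 2.31" for i in range(3))

steps = [int(s) for s in re.findall(pattern, fake_log)]
assert steps == [0, 1, 2]   # steps are zero-indexed
assert steps[-1] == 3 - 1   # last step is early_stop_steps - 1
assert len(steps) == 3      # exactly early_stop_steps steps were logged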

sub-packages/bionemo-testing/src/bionemo/testing/testing_callbacks.py

Lines changed: 20 additions & 3 deletions
@@ -49,15 +49,32 @@ class SignalAfterGivenStepCallback(Callback, CallbackMethods):
     Use this callback for pytest based Stop and go tests.
     """
 
-    def __init__(self, stop_step: int, signal_: signal.Signals = signal.SIGUSR2):
+    def __init__(
+        self,
+        stop_step: int,
+        signal_: signal.Signals = signal.SIGUSR2,
+        use_trainer_should_stop: bool = False,
+        stop_before_step: bool = False,
+    ):
         """Initializes the callback with the given stop_step."""
-        self.stop_step = stop_step
+        # Note that the stored stop step will be one less than the requested step if stop_before_step is True.
+        # This is because steps are zero-indexed, so running through step i normally yields i + 1 steps.
+        if stop_before_step:
+            self.stop_step = stop_step - 1
+        else:
+            self.stop_step = stop_step
         self.signal = signal_
+        # If True, ask the trainer to stop by setting should_stop to True rather than emitting a kill signal.
+        self.use_trainer_should_stop = use_trainer_should_stop
 
     def on_megatron_step_start(self, step: MegatronStep) -> MegatronStep:
         """Stop training if the global step is greater than or equal to the stop_step."""
         if step.trainer.global_step >= self.stop_step:
-            os.kill(os.getpid(), self.signal)
+            if self.use_trainer_should_stop:
+                # Ask the trainer to stop by setting should_stop to True rather than emitting a kill signal.
+                step.trainer.should_stop = True
+            else:
+                os.kill(os.getpid(), self.signal)
         return step
 
 