 # See the License for the specific language governing permissions and
 # limitations under the License.

+import io
 import os
+import re
+import shlex
 import subprocess
 import sys
+from contextlib import redirect_stderr, redirect_stdout

 import pytest
 import torch
@@ -73,6 +77,58 @@ def test_train_evo2_runs(tmp_path, num_steps=5):
     assert result.returncode == 0, "train_evo2 command failed."


+@pytest.mark.timeout(256)  # Fail if early stopping hangs and the test takes too long.
+@pytest.mark.slow
+def test_train_evo2_stops(tmp_path, num_steps=500000, early_stop_steps=3):
+    """Run `train_evo2` on mock data and verify that early stopping works.
+
+    Unlike `test_train_evo2_runs` above, this test calls the `train` function
+    in-process (capturing its stdout/stderr) rather than in a subshell, and
+    asserts that training halts after `early_stop_steps` instead of running
+    for the full `num_steps`.
+    """
+    open_port = find_free_network_port()
+    # train() runs in this process (not a subprocess), so set MASTER_PORT
+    # directly in os.environ; a copied environment dict would never be read.
+    os.environ["MASTER_PORT"] = str(open_port)
+
+    # Build the command string exactly as it would be typed on the CLI.
+    command = (
+        f"train_evo2 --mock-data --experiment-dir {tmp_path}/test_train "
+        "--model-size 1b_nv --num-layers 4 --hybrid-override-pattern SDH* "
+        "--no-activation-checkpointing --add-bias-output "
+        f"--max-steps {num_steps} --early-stop-on-step {early_stop_steps} --warmup-steps 1 --no-wandb "
+        "--seq-length 128 --hidden-dropout 0.1 --attention-dropout 0.1 "
+    )
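+    # Drop the program name ("train_evo2") so the remaining tokens can be fed
+    # straight to parse_args; nothing here requires the console script on PATH.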
+    command_parts_no_program = shlex.split(command)[1:]
+    args = parse_args(args=command_parts_no_program)
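+    # distributed_model_parallel_state() is assumed (from its name and its use
+    # here) to set up, and afterwards tear down, the distributed/model-parallel
+    # state that Megatron-based training needs.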
+    with distributed_model_parallel_state():
+        # Capture stdout/stderr during train function execution.
+        stdout_buffer = io.StringIO()
+        stderr_buffer = io.StringIO()
+        with redirect_stdout(stdout_buffer), redirect_stderr(stderr_buffer):
+            train(args=args)
+        # Get the captured output.
+        train_stdout = stdout_buffer.getvalue()
+        train_stderr = stderr_buffer.getvalue()
+        # Print the captured output for debugging.
+        print("TRAIN FUNCTION STDOUT:")
+        print(train_stdout)
+        print("TRAIN FUNCTION STDERR:")
+        print(train_stderr)
+
+    # Assert that training actually ran and logged the training loss.
+    assert "reduced_train_loss:" in train_stdout
+    pattern = r"\| global_step: (\d+) \|"
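+    # The progress log is assumed (from the regex above) to contain lines like
+    # "... | global_step: 2 | ..."; extract every logged step number from the
+    # captured stdout.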
+
+    def extract_global_steps(log_string):
+        matches = re.findall(pattern, log_string)
+        return [int(step) for step in matches]
+
+    global_step_ints = extract_global_steps(train_stdout)
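+    # Steps are zero-indexed, so stopping after `early_stop_steps` steps means
+    # the last logged step is early_stop_steps - 1 and exactly early_stop_steps
+    # step lines appear (rather than the full num_steps).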
+    assert global_step_ints[-1] == early_stop_steps - 1
+    assert len(global_step_ints) == early_stop_steps
+
+
 @pytest.mark.slow
 @pytest.mark.parametrize("model_size", ["7b_nv", "7b_arc_longcontext"])
 def test_train_single_gpu(tmp_path, model_size: str):