Commit 31bb529

Merge branch 'karpathy:master' into master
2 parents 19ca7b4 + 7ecd890 commit 31bb529

File tree: README.md, train_llama3.py (2 files changed, +35 -61 lines)

README.md

Lines changed: 8 additions & 2 deletions
```diff
@@ -1,6 +1,6 @@
 # llm.c
 
-LLMs in simple, pure C/CUDA with no need for 245MB of PyTorch or 107MB of cPython. Current focus is on pretraining, in particular reproducing the [GPT-2](https://github.com/openai/gpt-2) and [GPT-3](https://arxiv.org/abs/2005.14165) miniseries, along with a parallel PyTorch reference implementation in [train_gpt2.py](train_gpt2.py). You'll recognize this file as a slightly tweaked [nanoGPT](https://github.com/karpathy/nanoGPT), an earlier project of mine. Currently, llm.c is a bit faster than PyTorch Nightly (by about 7%). In addition to the bleeding edge mainline code in [train_gpt2.cu](train_gpt2.cu), we have a simple reference CPU fp32 implementation in ~1,000 lines of clean code in one file [train_gpt2.c](train_gpt2.c). I'd like this repo to only maintain C and CUDA code. Ports to other languages or repos are very welcome, but should be done in separate repos, and I am happy to link to them below in the "notable forks" section. Developer coordination happens in the [Discussions](https://github.com/karpathy/llm.c/discussions) and on Discord, either the `#llmc` channel on the [Zero to Hero](https://discord.gg/3zy8kqD9Cp) channel, or on `#llmdotc` on CUDA MODE Discord.
+LLMs in simple, pure C/CUDA with no need for 245MB of PyTorch or 107MB of cPython. Current focus is on pretraining, in particular reproducing the [GPT-2](https://github.com/openai/gpt-2) and [GPT-3](https://arxiv.org/abs/2005.14165) miniseries, along with a parallel PyTorch reference implementation in [train_gpt2.py](train_gpt2.py). You'll recognize this file as a slightly tweaked [nanoGPT](https://github.com/karpathy/nanoGPT), an earlier project of mine. Currently, llm.c is a bit faster than PyTorch Nightly (by about 7%). In addition to the bleeding edge mainline code in [train_gpt2.cu](train_gpt2.cu), we have a simple reference CPU fp32 implementation in ~1,000 lines of clean code in one file [train_gpt2.c](train_gpt2.c). I'd like this repo to only maintain C and CUDA code. Ports to other languages or repos are very welcome, but should be done in separate repos, and I am happy to link to them below in the "notable forks" section. Developer coordination happens in the [Discussions](https://github.com/karpathy/llm.c/discussions) and on Discord, either the `#llmc` channel on the [Zero to Hero](https://discord.gg/3zy8kqD9Cp) channel, or on `#llmdotc` on [GPU MODE](https://discord.gg/gpumode) Discord.
 
 ## quick start
 
@@ -211,10 +211,16 @@ Lastly, I will be a lot more sensitive to complexity in the root folder of the p
 
 - CUDA C++
   - [llm.cpp](https://github.com/gevtushenko/llm.c) by @[gevtushenko](https://github.com/gevtushenko): a port of this project using the [CUDA C++ Core Libraries](https://github.com/NVIDIA/cccl)
-    - A presentation this fork was covered in [this lecture](https://www.youtube.com/watch?v=WiB_3Csfj_Q) in the [CUDA MODE Discord Server](https://discord.gg/cudamode)
+    - A presentation this fork was covered in [this lecture](https://www.youtube.com/watch?v=WiB_3Csfj_Q) in the [GPU MODE Discord Server](https://discord.gg/cudamode)
+
+- C++/CUDA
+  - [llm.cpp](https://github.com/zhangpiu/llm.cpp/tree/master/llmcpp) by @[zhangpiu](https://github.com/zhangpiu): a port of this project using the [Eigen](https://gitlab.com/libeigen/eigen), supporting CPU/CUDA.
 
 - WebGPU C++
   - [gpu.cpp](https://github.com/AnswerDotAI/gpu.cpp) by @[austinvhuang](https://github.com/austinvhuang): a library for portable GPU compute in C++ using native WebGPU. Aims to be a general-purpose library, but also porting llm.c kernels to WGSL.
+
+- C++
+  - [llm.cpp](https://github.com/GaoYusong/llm.cpp) by @[GaoYusong](https://github.com/GaoYusong): a port of this project featuring a C++ single-header [tinytorch.hpp](https://github.com/GaoYusong/llm.cpp/blob/main/tinytorch.hpp) library
 
 - Go
   - [llm.go](https://github.com/joshcarp/llm.go) by @[joshcarp](https://github.com/joshcarp): a Go port of this project
```

train_llama3.py

Lines changed: 27 additions & 59 deletions
```diff
@@ -16,17 +16,17 @@
 TODO: add the actual commands
 """
 
+import argparse
 import os
 import math
 import glob
 import inspect
 from contextlib import nullcontext
 from dataclasses import dataclass
-import json
 from pathlib import Path
+import time
 from typing import (
     AbstractSet,
-    Callable,
     Collection,
     Dict,
     Iterator,
@@ -55,9 +55,6 @@
 # -----------------------------------------------------------------------------
 # PyTorch nn.Module definitions for the LLaMA 3.x model
 
-# using a global to toggle flash-attention
-FLASH = 0
-
 # Used in Grouped Query Attention (GQA), broadcasts the key and value tensors
 def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
     """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
```
```diff
@@ -157,6 +154,7 @@ def __init__(self, config):
         self.n_rep = self.n_head // self.n_kv_head
         self.hd = config.n_embd // config.n_head
         self.use_kv = config.use_kv
+        self.flash = config.flash
 
         self.c_attn = nn.Linear(config.n_embd, (config.n_head + 2 * config.n_kv_head) * self.hd, bias=False) # key, query, value projections
         self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False) # output projection
```
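The `n_rep` computed above (`n_head // n_kv_head`) is the grouped-query-attention ratio: each key/value head is shared by `n_rep` query heads. That sharing is done by `repeat_kv`, which appears as unchanged context earlier in this diff and documents itself as `torch.repeat_interleave(x, dim=2, repeats=n_rep)`. A self-contained sketch of that behaviour, assuming inputs of shape `(B, T, n_kv_head, head_dim)` (the function name here is illustrative, not code from this commit):

```python
import torch

def repeat_kv_sketch(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Equivalent to torch.repeat_interleave(x, dim=2, repeats=n_rep)."""
    B, T, n_kv_head, hd = x.shape
    if n_rep == 1:
        return x
    # add a broadcast axis, expand it n_rep times, then fold it into the head axis
    return (
        x[:, :, :, None, :]
        .expand(B, T, n_kv_head, n_rep, hd)
        .reshape(B, T, n_kv_head * n_rep, hd)
    )

x = torch.randn(2, 5, 2, 4)             # 2 KV heads
out = repeat_kv_sketch(x, n_rep=4)      # each KV head serves 4 query heads
assert out.shape == (2, 5, 8, 4)
assert torch.equal(out, torch.repeat_interleave(x, repeats=4, dim=2))
```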
```diff
@@ -186,9 +184,12 @@ def forward(self, x, freqs_cis=None, start_pos=None, mask=None):
 
         q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v)) # (B, NH, T, HD)
 
-        if FLASH:
+        if self.flash:
             # flashattention
-            y = F.scaled_dot_product_attention(q, k, v, mask)
+            # if T == 1 no need to mask, otherwise the function complains
+            # scaled_dot_product_attention expects a mask where value of True indicates that the element should take part in attention
+            # our mask is the opposite, so we need to invert it
+            y = F.scaled_dot_product_attention(q, k, v, mask == 0 if T > 1 else None)
         else:
             # manual implementation of attention
             # this materializes the large (T,T) matrix for all the queries and keys
```
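The new comments capture the key subtlety: `F.scaled_dot_product_attention` interprets a boolean `attn_mask` as True meaning "may attend", while the causal mask built elsewhere in this file is additive (`-inf` on blocked positions, 0 on allowed ones), so `mask == 0` flips the convention. A small self-contained check of that equivalence, under the assumption that the mask is the usual LLaMA-style additive causal mask:

```python
import math
import torch
import torch.nn.functional as F

B, NH, T, HD = 1, 2, 4, 8
q, k, v = (torch.randn(B, NH, T, HD) for _ in range(3))

# additive causal mask in the style this file builds: -inf blocks a position, 0 allows it
mask = torch.triu(torch.full((T, T), float("-inf")), diagonal=1)

# boolean view expected by scaled_dot_product_attention: True means "may attend"
y_flash = F.scaled_dot_product_attention(q, k, v, attn_mask=(mask == 0))

# manual path, as in the else-branch of forward(): add the float mask to the scores before softmax
att = (q @ k.transpose(-2, -1)) / math.sqrt(HD) + mask
y_manual = F.softmax(att, dim=-1) @ v

assert torch.allclose(y_flash, y_manual, atol=1e-5)
```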
```diff
@@ -257,6 +258,7 @@ class LlamaConfig:
     use_scaled_rope: bool = True
     max_gen_batch_size: int = 4
     use_kv: bool = True
+    flash: bool = False # use flashattention?
 
     def __init__(self, **kwargs):
         for k, v in kwargs.items():
@@ -402,7 +404,7 @@ def unpermute(w, n_heads, dim1, dim2):
     def from_pretrained_llama3_hf(cls, model_id):
         """Loads pretrained LLaMA model weights from HuggingFace"""
         from transformers import AutoModelForCausalLM, AutoTokenizer
-        assert model_id == "meta-llama/Meta-Llama-3.1-8B", "Only the 8B-bae model is supported for now"
+        assert model_id == "meta-llama/Meta-Llama-3.1-8B", "Only the 8B-base model is supported for now"
         model_args = LlamaConfig()
 
         model = AutoModelForCausalLM.from_pretrained(model_id)
@@ -477,7 +479,6 @@ def generate(
         max_gen_len: int,
         temperature: float = 0.6,
         top_p: float = 0.9,
-        logprobs: bool = False,
         echo: bool = False,
     ) -> Tuple[List[List[int]], Optional[List[List[float]]]]:
         """
@@ -488,45 +489,35 @@ def generate(
             max_gen_len (int): Maximum length of the generated text sequence.
             temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
             top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
-            logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False.
             echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.
 
         Returns:
-            Tuple[List[List[int]], Optional[List[List[float]]]]: A tuple containing generated token sequences and, if logprobs is True, corresponding token log probabilities.
+            Tuple[List[List[int]], Optional[List[List[float]]]]: A tuple containing generated token sequences.
 
         Note:
             This method uses the provided prompts as a basis for generating text. It employs nucleus sampling to produce text with controlled randomness.
-            If logprobs is True, token log probabilities are computed for each generated token.
 
         """
         bsz = len(prompt_tokens)
-        assert bsz <= self.config.max_gen_batch_size, (bsz, self.config.max_gen_batch_size)
+        assert bsz <= self.config.max_gen_batch_size, f"Batch size {bsz} exceeds the maximum generation batch size {self.config.max_gen_batch_size}"
         device = next(self.parameters()).device
 
         min_prompt_len = min(len(t) for t in prompt_tokens)
         max_prompt_len = max(len(t) for t in prompt_tokens)
-        assert max_prompt_len <= self.config.block_size
+        assert max_prompt_len <= self.config.block_size, f"Prompt length {max_prompt_len} exceeds the maximum block size {self.config.block_size}"
         total_len = min(self.config.block_size, max_gen_len + max_prompt_len)
 
         pad_id = self.tokenizer.pad_id
         tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device=device)
-        for k, t in enumerate(prompt_tokens):
-            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device=device)
-        if logprobs:
-            token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
+        for idx, t in enumerate(prompt_tokens):
+            tokens[idx, : len(t)] = torch.tensor(t, dtype=torch.long, device=device)
 
         prev_pos = 0
         eos_reached = torch.tensor([False] * bsz, device=device)
         input_text_mask = tokens != pad_id
 
         if min_prompt_len == total_len:
             logits, _ = self.forward(tokens, start_pos=prev_pos)
-            token_logprobs = -F.cross_entropy(
-                input=logits.transpose(1, 2),
-                target=tokens,
-                reduction="none",
-                ignore_index=pad_id,
-            )
 
         stop_tokens = torch.tensor(list(self.tokenizer.stop_tokens)).to(device)
 
```
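To make the prompt-packing step above concrete, here is a small self-contained sketch of the same idea with toy values (the pad id and prompts are made up, not taken from this file): prompts are left-aligned into a pad-filled `(bsz, total_len)` buffer, and `input_text_mask` then marks which positions hold real prompt tokens rather than padding.

```python
import torch

pad_id = 128255                           # toy value; the real pad_id comes from the tokenizer
prompt_tokens = [[10, 20, 30], [40, 50]]  # two prompts of different lengths
bsz, total_len = len(prompt_tokens), 6

# left-align every prompt inside a pad-filled (bsz, total_len) buffer
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long)
for idx, t in enumerate(prompt_tokens):
    tokens[idx, : len(t)] = torch.tensor(t, dtype=torch.long)

# True where a real prompt token sits; generation later avoids overwriting these positions
input_text_mask = tokens != pad_id
print(tokens)
print(input_text_mask)
```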
```diff
@@ -542,41 +533,25 @@ def generate(
             # only replace token if prompt has already been generated
             next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token)
             tokens[:, cur_pos] = next_token
-            if logprobs:
-                token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
-                    input=logits.transpose(1, 2),
-                    target=tokens[:, prev_pos + 1 : cur_pos + 1],
-                    reduction="none",
-                    ignore_index=pad_id,
-                )
-            eos_reached |= (~input_text_mask[:, cur_pos]) & (
-                torch.isin(next_token, stop_tokens)
-            )
+            eos_reached |= ~input_text_mask[:, cur_pos] & torch.isin(next_token, stop_tokens)
             prev_pos = cur_pos
             if all(eos_reached):
                 break
 
-        if logprobs:
-            token_logprobs = token_logprobs.tolist()
-        out_tokens, out_logprobs = [], []
+        out_tokens = []
         for i, toks in enumerate(tokens.tolist()):
             # cut to max gen len
             start = 0 if echo else len(prompt_tokens[i])
             toks = toks[start : len(prompt_tokens[i]) + max_gen_len]
-            probs = None
-            if logprobs:
-                probs = token_logprobs[i][start : len(prompt_tokens[i]) + max_gen_len]
             # cut to after eos tok if any
             for stop_token in self.tokenizer.stop_tokens:
                 try:
                     eos_idx = toks.index(stop_token)
                     toks = toks[:eos_idx]
-                    probs = probs[:eos_idx] if logprobs else None
                 except ValueError:
                     pass
             out_tokens.append(toks)
-            out_logprobs.append(probs)
-        return (out_tokens, out_logprobs if logprobs else None)
+        return out_tokens
 
 # -----------------------------------------------------------------------------
 # sampling utils
```
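The generate() docstring above says the method "employs nucleus sampling"; the helper itself presumably lives in the sampling-utils section that begins here and is untouched by this commit. For reference, a minimal self-contained sketch of top-p sampling (the function name and the exact cutoff convention are assumptions, not code from this diff):

```python
import torch

def sample_top_p_sketch(probs: torch.Tensor, top_p: float) -> torch.Tensor:
    # probs: (B, vocab) softmax probabilities. Keep the smallest prefix of
    # tokens (sorted by probability) whose mass covers top_p, renormalize, sample.
    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    mask = cumulative - sorted_probs > top_p   # tokens lying entirely beyond the top_p mass
    sorted_probs[mask] = 0.0
    sorted_probs.div_(sorted_probs.sum(dim=-1, keepdim=True))
    next_sorted = torch.multinomial(sorted_probs, num_samples=1)
    return torch.gather(sorted_idx, -1, next_sorted)   # (B, 1) sampled token ids

# example: with top_p=0.9 the lowest-probability tail token is dropped before sampling
probs = torch.tensor([[0.6, 0.3, 0.05, 0.05]])
print(sample_top_p_sketch(probs, top_p=0.9))
```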
```diff
@@ -959,18 +934,16 @@ def print0(*args, **kwargs):
         print(*args, **kwargs)
 
 if __name__ == "__main__":
-    import time
-    import argparse
     print0(f"Running pytorch {torch.version.__version__}")
 
     # default settings will overfit a tiny batch of data
     # and save model weights and debug state to disk on the first iteration
     parser = argparse.ArgumentParser()
     parser.add_argument("--use_hf", type=int, default=1, help="use HuggingFace (default) or use Meta's model")
-    parser.add_argument("--ckpt_dir", type=str, default=None, help="path to llama3 model checkpoint")
-    parser.add_argument("--tokenizer_path", type=str, default=None, help="path to llama3 tokenizer")
+    parser.add_argument("--ckpt_dir", type=str, default=None, help="path to llama3 model checkpoint (needed if use_hf=0)")
+    parser.add_argument("--tokenizer_path", type=str, default=None, help="path to llama3 tokenizer (needed if use_hf=0)")
     # file system input / output
-    parser.add_argument("--input_bin", type=str, default="dev/data/tinystories/TinyStories_val.bin", help="input .bin to train on")
+    parser.add_argument("--input_bin", type=str, default="dev/data/tinyshakespeare/tiny_shakespeare_val.bin", help="input .bin to train on")
     parser.add_argument("--input_val_bin", type=str, default="", help="input .bin to eval validation loss on")
     parser.add_argument("--output_dir", type=str, default="", help="output directory to which to write logs and checkpoints")
     parser.add_argument("--model", type=str, default="meta-llama/Meta-Llama-3.1-8B", help="chose the llama model")
@@ -982,7 +955,7 @@ def print0(*args, **kwargs):
     parser.add_argument("--num_iterations", type=int, default=10, help="number of iterations to run")
     parser.add_argument("--inference_only", type=int, default=0, help="only run inference")
     # optimization
-    parser.add_argument("--learning_rate", type=float, default=1e-4, help="learning rate warmup iterations")
+    parser.add_argument("--learning_rate", type=float, default=1e-5, help="learning rate warmup iterations")
     parser.add_argument("--warmup_iters", type=int, default=0, help="learning rate warmup iterations")
     parser.add_argument("--learning_rate_decay_frac", type=float, default=1.0, help="learning rate warmup iterations")
     parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay")
@@ -998,7 +971,6 @@ def print0(*args, **kwargs):
     # memory management
     parser.add_argument("--device", type=str, default="", help="by default we autodetect, or set it here")
     parser.add_argument("--compile", type=int, default=0, help="torch.compile the model")
-    parser.add_argument("--flash", type=int, default=0, help="use flash attention")
     parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|float16|bfloat16")
     parser.add_argument("--zero_stage", type=int, default=0, help="zero redundancy optimizer stage (0/1/2/3)")
     # python -> C bridge
@@ -1052,9 +1024,9 @@ def print0(*args, **kwargs):
             device = "cuda"
         elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
             device = "mps"
-    print(f"using device: {device}")
     device_type = 'cuda' if 'cuda' in device else 'cpu'
-    assert device_type in {'cuda'} # we need to load LLaMA as bf16 on CUDA
+    assert device_type in {'cuda'}, "GPU required to run LLaMA 3" # we need to load LLaMA as bf16 on CUDA
+    print(f"using device: {device}")
 
     # calculate gradient accumulation from the desired total batch size and the current run configuration
     tokens_per_fwdbwd = B * T * ddp_world_size
@@ -1077,16 +1049,12 @@ def print0(*args, **kwargs):
     if args.tensorcores:
         torch.set_float32_matmul_precision('high')
 
-    # turn on/off flash attention
-    assert args.flash in {0, 1}
-    FLASH = args.flash
-
     # init the model
-    assert args.ckpt_dir is not None and os.path.exists(args.ckpt_dir), f"llama3 ckpt dir {args.ckpt_dir} does not exist"
-    assert args.tokenizer_path is not None and os.path.exists(args.tokenizer_path), f"llama3 tokenizer path {args.tokenizer_path} does not exist"
     if args.use_hf:
         model = LLaMA.from_pretrained_llama3_hf(args.model)
     else: # use Meta's checkpoint
+        assert args.ckpt_dir is not None and os.path.exists(args.ckpt_dir), f"llama3 ckpt dir {args.ckpt_dir} does not exist"
+        assert args.tokenizer_path is not None and os.path.exists(args.tokenizer_path), f"llama3 tokenizer path {args.tokenizer_path} does not exist"
         model = LLaMA.from_pretrained_llama3_meta(args.ckpt_dir, args.tokenizer_path)
 
     model.train()
@@ -1201,7 +1169,7 @@ def get_lr(it):
             else: # Meta
                 prompt_tokens = [model.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
 
-            generation_tokens, _ = model.generate(prompt_tokens, max_gen_len=64, temperature=0.6, top_p=0.9, logprobs=False, echo=False)
+            generation_tokens = model.generate(prompt_tokens, max_gen_len=64, temperature=0.6, top_p=0.9, echo=False)
             results = [{"generation": model.tokenizer.decode(t)} for t in generation_tokens]
             for prompt, result in zip(prompts, results):
                 print(prompt, end="")
```
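With the logprobs path removed, generate() now returns a plain list of token-id lists, which is what the updated call site above consumes. A usage sketch of the new calling convention (the prompt string is made up, the model is assumed to have been built earlier in this script via one of the from_pretrained_llama3_* paths, and the encode/decode calls follow the Meta-tokenizer branch shown above):

```python
# assumes `model` was already constructed via LLaMA.from_pretrained_llama3_hf(...)
# or LLaMA.from_pretrained_llama3_meta(...); the prompt below is just an example
prompts = ["Clearly, the meaning of life is"]
prompt_tokens = [model.tokenizer.encode(x, bos=True, eos=False) for x in prompts]

# new signature: a single list of generated token-id lists, no (tokens, logprobs) tuple
generation_tokens = model.generate(prompt_tokens, max_gen_len=64, temperature=0.6, top_p=0.9, echo=False)
for prompt, toks in zip(prompts, generation_tokens):
    print(prompt + model.tokenizer.decode(toks))
```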
