Commit 369672c

Merge branch 'karpathy:master' into master
2 parents d0e7a59 + 29aacba

12 files changed: +423 −62 lines

Makefile

Lines changed: 1 addition & 0 deletions

@@ -62,6 +62,7 @@ $(info ---------------------------------------------)
 
 ifneq ($(OS), Windows_NT)
   NVCC := $(shell which nvcc 2>/dev/null)
+  NVCC_LDFLAGS += -lnvidia-ml
 
   # Function to test if the compiler accepts a given flag.
   define check_and_add_flag
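
The new `-lnvidia-ml` flag links against NVML (the NVIDIA Management Library), which lets host code query device status such as power draw. As an illustration only of what this flag enables (the calls below are from the public NVML API; this snippet is not part of the commit):

// nvml_power.c -- illustrative NVML probe; build with: gcc nvml_power.c -lnvidia-ml
#include <stdio.h>
#include <nvml.h>

int main(void) {
    nvmlDevice_t dev;
    unsigned int milliwatts;
    if (nvmlInit() != NVML_SUCCESS) { return 1; }  // initialize the NVML library
    nvmlDeviceGetHandleByIndex(0, &dev);           // handle to the first GPU
    if (nvmlDeviceGetPowerUsage(dev, &milliwatts) == NVML_SUCCESS) {
        printf("GPU 0 power draw: %.1f W\n", milliwatts / 1000.0);
    }
    nvmlShutdown();
    return 0;
}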

dev/cuda/Makefile

Lines changed: 7 additions & 4 deletions

@@ -9,15 +9,15 @@ ifeq ($(NVCC),)
 endif
 
 ifneq ($(CI),true) # if not in CI, then use the GPU query
-ifndef GPU_COMPUTE_CAPABILITY # set to defaults if: make GPU_COMPUTE_CAPABILITY=
+  ifndef GPU_COMPUTE_CAPABILITY # set to defaults if: make GPU_COMPUTE_CAPABILITY=
     GPU_COMPUTE_CAPABILITY = $(shell __nvcc_device_query) # assume if NVCC is present, then this likely is too
     GPU_COMPUTE_CAPABILITY := $(strip $(GPU_COMPUTE_CAPABILITY))
   endif
 endif
 
 # Compiler flags
-ifeq ($(GPU_COMPUTE_CAPABILITY),) # set to defaults if: make GPU_COMPUTE_CAPABILITY=
-CFLAGS = -O3 --use_fast_math
+ifeq ($(GPU_COMPUTE_CAPABILITY),) # set to defaults if: make GPU_COMPUTE_CAPABILITY=
+  CFLAGS = -O3 --use_fast_math
 else
   CFLAGS = -O3 --use_fast_math --generate-code arch=compute_$(GPU_COMPUTE_CAPABILITY),code=[compute_$(GPU_COMPUTE_CAPABILITY),sm_$(GPU_COMPUTE_CAPABILITY)]
 endif

@@ -30,7 +30,8 @@ MPI_PATHS = -I/usr/lib/x86_64-linux-gnu/openmpi/include -L/usr/lib/x86_64-linux-
 	$(NVCC) $(CFLAGS) $(NVCCFLAGS) $< -o $@
 
 # Build all targets
-TARGETS = adamw attention_backward attention_forward classifier_fused crossentropy_forward crossentropy_softmax_backward encoder_backward encoder_forward gelu_backward gelu_forward layernorm_backward layernorm_forward matmul_backward matmul_backward_bias matmul_forward nccl_all_reduce residual_forward softmax_forward trimat_forward fused_residual_forward global_norm
+TARGETS = adamw attention_backward attention_forward classifier_fused crossentropy_forward crossentropy_softmax_backward encoder_backward encoder_forward gelu_backward gelu_forward layernorm_backward layernorm_forward matmul_backward matmul_backward_bias matmul_forward nccl_all_reduce residual_forward softmax_forward trimat_forward fused_residual_forward global_norm permute
+
 all: $(TARGETS)
 all_ptx: $(TARGETS:%=%.ptx)
 all_sass: $(TARGETS:%=%.sass)

@@ -64,6 +65,8 @@ matmul_backward: matmul_backward.cu
 adamw: adamw.cu
 global_norm: global_norm.cu
 
+permute: permute.cu
+
 # NCCL communication kernels
 nccl_all_reduce: nccl_all_reduce.cu
 	$(NVCC) -lmpi -lnccl $(NVCCFLAGS) $(MPI_PATHS) nccl_all_reduce.cu -o nccl_all_reduce

dev/cuda/attention_backward.cu

Lines changed: 1 addition & 0 deletions

@@ -1137,6 +1137,7 @@ int main(int argc, char **argv) {
     free(dinp);
     free(dpreatt);
     free(datt);
+    free(h_dinp);
     cudaCheck(cudaFree(d_inp));
     cudaCheck(cudaFree(d_qkvr));
     cudaCheck(cudaFree(d_preatt));

dev/cuda/attention_forward.cu

Lines changed: 1 addition & 0 deletions

@@ -1377,6 +1377,7 @@ int main(int argc, char **argv) {
     cudaCheck(cudaFree(d_preatt));
     cudaCheck(cudaFree(d_att));
     cudaCheck(cudaFree(d_inp));
+    cudaCheck(cudaFree(d_stats));
     cublasDestroy(cublas_handle);
 
 #ifdef ENABLE_CUDNN

dev/cuda/classifier_fused.cu

Lines changed: 1 addition & 0 deletions

@@ -766,6 +766,7 @@ int main(int argc, char **argv) {
     cudaCheck(cudaFree(d_logits));
     cudaCheck(cudaFree(d_dlosses));
     cudaCheck(cudaFree(d_targets));
+    cudaCheck(cudaFree(d_dlogits_no_pad));
 
     return 0;
 }

dev/cuda/nccl_all_reduce.cu

Lines changed: 1 addition & 0 deletions

@@ -193,5 +193,6 @@ int main(int argc, char **argv) {
 
     free(all_reduce_buffer_host);
     cudaCheck(cudaFree(all_reduce_buffer));
+    cudaCheck(cudaFree(all_reduce_buffer_recv));
     multi_gpu_config_free(&multi_gpu_config);
 }

dev/cuda/permute.cu

Lines changed: 181 additions & 0 deletions (new file)

/*
Kernels to demonstrate the permute operation.

Compile example:
nvcc -O3 permute.cu -o permute

The goal is to permute a 4D matrix from its original shape (dim1, dim2, dim3, dim4) to a new shape (dim4, dim3, dim1, dim2).

Before permuting, we need to understand how to access elements in the flattened (linear) form of the matrix.

Given:

dim1 = size of the 1st dimension
dim2 = size of the 2nd dimension
dim3 = size of the 3rd dimension
dim4 = size of the 4th dimension

any element of the 4D matrix sits at a position (i1, i2, i3, i4), where:

i1 is the index in dimension 1
i2 is the index in dimension 2
i3 is the index in dimension 3
i4 is the index in dimension 4

If these index calculations seem hard to follow at first, observe the pattern below; with a little practice it becomes a mental model.

To recover the indices from a linear index idx, use the following formulas:

i1 = (idx / (dim2 * dim3 * dim4)) % dim1;
i2 = (idx / (dim3 * dim4)) % dim2;
i3 = (idx / dim4) % dim3;
i4 = idx % dim4;

Pattern: to find the index for any dimension, divide idx by the product of all subsequent dimensions,
then take the result modulo the size of the current dimension.

Conversely, the linear index in the flattened 1D array is:

linear_idx = i1 * (dim2 * dim3 * dim4) + i2 * (dim3 * dim4) + i3 * dim4 + i4

This linear index uniquely identifies the position of the element in the 1D array.

To permute the matrix, we rearrange the indices according to the new shape.
In this case, we permute from (dim1, dim2, dim3, dim4) to (dim4, dim3, dim1, dim2).

The new dimension order after the permutation is as follows:

dim1 becomes the new 3rd dimension.
dim2 becomes the new 4th dimension.
dim3 becomes the new 2nd dimension.
dim4 becomes the new 1st dimension.

permuted_idx = i4 * (dim3 * dim1 * dim2) + i3 * (dim1 * dim2) + i1 * dim2 + i2;

Here is how this works:

i4 * (dim3 * dim1 * dim2): how many complete dim3 x dim1 x dim2 blocks come before the current i4 block.
i3 * (dim1 * dim2): the offset within the current i4 block, i.e. which i3 block we are in.
i1 * dim2: the offset within the current i3 block, i.e. which i1 row we are in.
i2: the offset within the current i1 row.

Finally, we store the value at index idx of the original matrix at index permuted_idx of the permuted matrix.

--------------------------------------------------------------------------------------------------------------------------------------------------------

The same approach can be followed to permute matrices with any number of dimensions.
*/

#include <cuda_runtime.h>
#include <stdio.h>
#include <stdlib.h>
#include <ctime>
#include <cmath>

#include "common.h"

// CPU function to permute a 4D matrix
void permute_cpu(const float* matrix, float* out_matrix, int dim1, int dim2, int dim3, int dim4) {
    int total_threads = dim1 * dim2 * dim3 * dim4;

    for (int idx = 0; idx < total_threads; idx++) {
        // Calculate the 4D indices from the linear index
        int i1 = (idx / (dim2 * dim3 * dim4)) % dim1;
        int i2 = (idx / (dim3 * dim4)) % dim2;
        int i3 = (idx / dim4) % dim3;
        int i4 = idx % dim4;

        // Compute the new index for the permuted matrix
        // Transpose from (dim1, dim2, dim3, dim4) to (dim4, dim3, dim1, dim2)
        int permuted_idx = i4 * (dim3 * dim1 * dim2) + i3 * (dim1 * dim2) + i1 * dim2 + i2;
        out_matrix[permuted_idx] = matrix[idx];
    }
}

// CUDA kernel to permute a 4D matrix
__global__ void permute_kernel(const float* matrix, float* out_matrix, int dim1, int dim2, int dim3, int dim4) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;

    // Ensure index is within bounds
    if (idx < dim1 * dim2 * dim3 * dim4) {
        // Calculate the 4D indices from the linear index
        int i1 = (idx / (dim2 * dim3 * dim4)) % dim1;
        int i2 = (idx / (dim3 * dim4)) % dim2;
        int i3 = (idx / dim4) % dim3;
        int i4 = idx % dim4;

        // Compute the new index for the permuted matrix
        // Transpose from (dim1, dim2, dim3, dim4) to (dim4, dim3, dim1, dim2)
        int permuted_idx = i4 * (dim3 * dim1 * dim2) + i3 * (dim1 * dim2) + i1 * dim2 + i2;
        out_matrix[permuted_idx] = matrix[idx];
    }
}

int main() {
    int dim_1 = 24;
    int dim_2 = 42;
    int dim_3 = 20;
    int dim_4 = 32;

    // Set up the device
    int deviceIdx = 0;
    cudaSetDevice(deviceIdx);
    cudaDeviceProp deviceProp;
    cudaGetDeviceProperties(&deviceProp, deviceIdx);
    printf("Device %d: %s\n", deviceIdx, deviceProp.name);

    // Allocate host memory and initialize the matrix with random values
    float* matrix = make_random_float(dim_1 * dim_2 * dim_3 * dim_4);
    float* permuted_matrix = (float*)malloc(dim_1 * dim_2 * dim_3 * dim_4 * sizeof(float));

    // Allocate device memory
    float *d_matrix, *d_permuted_matrix;
    cudaMalloc(&d_matrix, dim_1 * dim_2 * dim_3 * dim_4 * sizeof(float));
    cudaMalloc(&d_permuted_matrix, dim_1 * dim_2 * dim_3 * dim_4 * sizeof(float));

    // Copy matrix from host to device
    cudaMemcpy(d_matrix, matrix, dim_1 * dim_2 * dim_3 * dim_4 * sizeof(float), cudaMemcpyHostToDevice);

    // Perform permutation on CPU
    clock_t start = clock();
    permute_cpu(matrix, permuted_matrix, dim_1, dim_2, dim_3, dim_4);
    clock_t end = clock();
    double elapsed_time_cpu = (double)(end - start) * 1000.0 / CLOCKS_PER_SEC; // in ms

    // Define block and grid sizes
    dim3 blockSize(256);
    int totalThreads = dim_1 * dim_2 * dim_3 * dim_4;
    int gridSize = (totalThreads + blockSize.x - 1) / blockSize.x; // Compute grid size

    // Launch CUDA kernel to perform permutation
    permute_kernel<<<gridSize, blockSize>>>(d_matrix, d_permuted_matrix, dim_1, dim_2, dim_3, dim_4);
    cudaDeviceSynchronize(); // Ensure kernel execution is complete

    // Verify results
    printf("Checking correctness...\n");
    validate_result(d_permuted_matrix, permuted_matrix, "permuted_matrix", dim_1 * dim_2 * dim_3 * dim_4, 1e-5f);
    printf("All results match.\n\n");

    // Benchmark the kernel
    int repeat_times = 1000;
    float elapsed_time = benchmark_kernel(repeat_times, permute_kernel,
                                          d_matrix, d_permuted_matrix, dim_1, dim_2, dim_3, dim_4);
    printf("time gpu %.4f ms\n", elapsed_time);
    printf("time cpu %.4f ms\n", elapsed_time_cpu);

    // Free allocated memory
    free(matrix);
    free(permuted_matrix);
    cudaFree(d_matrix);
    cudaFree(d_permuted_matrix);

    return 0;
}
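
Since the permutation is a bijection on indices, a quick way to sanity-check the index arithmetic is a host-side round trip: map every linear index forward exactly as the kernel does, then decode the result in the permuted shape and confirm you land back on the same (i1, i2, i3, i4). A minimal standalone sketch (not part of the commit):

#include <assert.h>
#include <stdio.h>

// Verify that the map (dim1,dim2,dim3,dim4) -> (dim4,dim3,dim1,dim2) round-trips.
int main(void) {
    const int dim1 = 3, dim2 = 5, dim3 = 4, dim4 = 2;
    const int total = dim1 * dim2 * dim3 * dim4;
    for (int idx = 0; idx < total; idx++) {
        int i1 = (idx / (dim2 * dim3 * dim4)) % dim1;
        int i2 = (idx / (dim3 * dim4)) % dim2;
        int i3 = (idx / dim4) % dim3;
        int i4 = idx % dim4;
        // forward map, exactly as in permute_kernel
        int p = i4 * (dim3 * dim1 * dim2) + i3 * (dim1 * dim2) + i1 * dim2 + i2;
        // inverse: decode p in the permuted shape (dim4, dim3, dim1, dim2)
        int j4 = (p / (dim3 * dim1 * dim2)) % dim4;
        int j3 = (p / (dim1 * dim2)) % dim3;
        int j1 = (p / dim2) % dim1;
        int j2 = p % dim2;
        assert(j1 == i1 && j2 == i2 && j3 == i3 && j4 == i4);
    }
    printf("index map round-trips for all %d elements\n", total);
    return 0;
}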

dev/cuda/trimat_forward.cu

Lines changed: 1 addition & 0 deletions

@@ -643,6 +643,7 @@ int main(int argc, char **argv) {
     free(inp);
     cudaCheck(cudaFree(d_out));
     cudaCheck(cudaFree(d_inp));
+    cudaCheck(cudaFree(d_qkvr));
     cublasDestroy(cublas_handle);
 
     return 0;

llmc/adamw.cuh

Lines changed: 21 additions & 1 deletion

@@ -61,6 +61,16 @@ __global__ void adamw_kernel3(Tp* params_memory, float* master_params_memory, Tg
     );
 }
 
+template <typename Tp>
+__global__ void init_from_master_kernel(Tp* params_memory, float* master_params_memory, size_t num_parameters,
+                                        ptrdiff_t w_stride, ptrdiff_t s_stride, unsigned int seed) {
+    size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx >= num_parameters) { return; }
+    params_memory += blockIdx.y * w_stride; // adjust for layer offset
+    master_params_memory += blockIdx.y * s_stride;
+    stochastic_rounding(master_params_memory[idx], &params_memory[idx], seed);
+}
+
 template <typename Tp, typename Tg>
 void adamw_update(Tp* params_memory, float* master_params_memory, Tg* grads_memory, float* m_memory, float* v_memory, size_t num_parameters,
                   ptrdiff_t w_stride, ptrdiff_t g_stride, ptrdiff_t s_stride, int num_slices, float learning_rate, float beta1, float beta2, int t, float eps, float weight_decay,

@@ -75,4 +85,14 @@ void adamw_update(Tp* params_memory, float* master_params_memory, Tg* grads_memo
                   learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay,
                   grad_scale, seed);
     cudaCheck(cudaGetLastError());
-}
+}
+
+template <typename Tp>
+void init_from_master(Tp* params_memory, float* master_params_memory, size_t num_parameters,
+                      ptrdiff_t w_stride, ptrdiff_t s_stride, int num_slices, unsigned int seed, cudaStream_t stream) {
+    int block_size = 512; // must match block size of adamw_update so that RNG also matches
+    int num_blocks = CEIL_DIV(num_parameters, block_size);
+    init_from_master_kernel<<<dim3(num_blocks, num_slices), block_size, 0, stream>>>
+        (params_memory, master_params_memory, num_parameters, w_stride, s_stride, seed);
+    cudaCheck(cudaGetLastError());
+}
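
The new init_from_master re-materializes the low-precision weights from the fp32 master copy via stochastic_rounding (defined elsewhere in llmc); matching the block size and seed of adamw_update keeps the random rounding decisions reproducible across the two paths. As a rough illustration of the idea only (llm.c's actual device implementation differs in details), stochastic rounding from fp32 to bf16 can be sketched on the CPU like this:

#include <stdint.h>
#include <string.h>
#include <stdio.h>

// Illustrative CPU sketch of stochastic rounding fp32 -> bf16 (not llm.c's kernel).
// bf16 keeps the top 16 bits of an fp32. Instead of truncating (biased down) or
// rounding to nearest, add a random offset below the cut so the result is
// unbiased in expectation. NaN/Inf handling is omitted for brevity.
static uint16_t stochastic_round_bf16(float x, uint32_t rnd) {
    uint32_t bits;
    memcpy(&bits, &x, sizeof(bits)); // type-pun via memcpy, avoiding UB
    bits += rnd & 0xFFFFu;           // random carry-in below the bf16 mantissa cut
    return (uint16_t)(bits >> 16);   // keep the upper 16 bits
}

int main(void) {
    // with rnd = 0 the value truncates down; a large rnd can round it up
    printf("%04x\n", stochastic_round_bf16(1.00001f, 0u));      // 3f80
    printf("%04x\n", stochastic_round_bf16(1.00001f, 0xFFFFu)); // 3f81
    return 0;
}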
