Commit e2bd030

Merge branch 'ggerganov:master' into fix-vulkan-shader-warnings
2 parents 19aa132 + 4c676c8


59 files changed: +2005 / -1599 lines

Makefile

Lines changed: 49 additions & 12 deletions
@@ -325,9 +325,9 @@ ifdef LLAMA_DEBUG
 endif
 else
     MK_CPPFLAGS += -DNDEBUG
-    MK_CFLAGS += -O3
-    MK_CXXFLAGS += -O3
-    MK_NVCCFLAGS += -O3
+    MK_CFLAGS += -O3 -g
+    MK_CXXFLAGS += -O3 -g
+    MK_NVCCFLAGS += -O3 -g
 endif

 ifdef LLAMA_SANITIZE_THREAD
@@ -528,10 +528,21 @@ ifndef GGML_NO_ACCELERATE
     endif
 endif # GGML_NO_ACCELERATE

+ifdef GGML_MUSA
+    CC := clang
+    CXX := clang++
+    GGML_CUDA := 1
+    MK_CPPFLAGS += -DGGML_USE_MUSA
+endif
+
 ifndef GGML_NO_OPENMP
     MK_CPPFLAGS += -DGGML_USE_OPENMP
     MK_CFLAGS += -fopenmp
     MK_CXXFLAGS += -fopenmp
+    ifdef GGML_MUSA
+        MK_CPPFLAGS += -I/usr/lib/llvm-10/include/openmp
+        MK_LDFLAGS += -L/usr/lib/llvm-10/lib
+    endif # GGML_MUSA
 endif # GGML_NO_OPENMP

 ifdef GGML_OPENBLAS
@@ -582,15 +593,27 @@ else
 endif # GGML_CUDA_FA_ALL_QUANTS

 ifdef GGML_CUDA
-    ifneq ('', '$(wildcard /opt/cuda)')
-        CUDA_PATH ?= /opt/cuda
+    ifdef GGML_MUSA
+        ifneq ('', '$(wildcard /opt/musa)')
+            CUDA_PATH ?= /opt/musa
+        else
+            CUDA_PATH ?= /usr/local/musa
+        endif
+
+        MK_CPPFLAGS += -DGGML_USE_CUDA -I$(CUDA_PATH)/include
+        MK_LDFLAGS += -lmusa -lmublas -lmusart -lpthread -ldl -lrt -L$(CUDA_PATH)/lib -L/usr/lib64
+        MK_NVCCFLAGS += -x musa -mtgpu --cuda-gpu-arch=mp_22
     else
-        CUDA_PATH ?= /usr/local/cuda
-    endif
+        ifneq ('', '$(wildcard /opt/cuda)')
+            CUDA_PATH ?= /opt/cuda
+        else
+            CUDA_PATH ?= /usr/local/cuda
+        endif

-    MK_CPPFLAGS += -DGGML_USE_CUDA -I$(CUDA_PATH)/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include -DGGML_CUDA_USE_GRAPHS
-    MK_LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L$(CUDA_PATH)/lib64 -L/usr/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib -L$(CUDA_PATH)/lib64/stubs -L/usr/lib/wsl/lib
-    MK_NVCCFLAGS += -use_fast_math
+        MK_CPPFLAGS += -DGGML_USE_CUDA -I$(CUDA_PATH)/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include -DGGML_CUDA_USE_GRAPHS
+        MK_LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L$(CUDA_PATH)/lib64 -L/usr/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib -L$(CUDA_PATH)/lib64/stubs -L/usr/lib/wsl/lib
+        MK_NVCCFLAGS += -use_fast_math
+    endif # GGML_MUSA

     OBJ_GGML += ggml/src/ggml-cuda.o
     OBJ_GGML += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/*.cu))
@@ -600,9 +623,11 @@ ifdef LLAMA_FATAL_WARNINGS
     MK_NVCCFLAGS += -Werror all-warnings
 endif # LLAMA_FATAL_WARNINGS

+ifndef GGML_MUSA
 ifndef JETSON_EOL_MODULE_DETECT
     MK_NVCCFLAGS += --forward-unknown-to-host-compiler
 endif # JETSON_EOL_MODULE_DETECT
+endif # GGML_MUSA

 ifdef LLAMA_DEBUG
     MK_NVCCFLAGS += -lineinfo
@@ -615,8 +640,12 @@ endif # GGML_CUDA_DEBUG
 ifdef GGML_CUDA_NVCC
     NVCC = $(CCACHE) $(GGML_CUDA_NVCC)
 else
-    NVCC = $(CCACHE) nvcc
-endif #GGML_CUDA_NVCC
+    ifdef GGML_MUSA
+        NVCC = $(CCACHE) mcc
+    else
+        NVCC = $(CCACHE) nvcc
+    endif # GGML_MUSA
+endif # GGML_CUDA_NVCC

 ifdef CUDA_DOCKER_ARCH
     MK_NVCCFLAGS += -Wno-deprecated-gpu-targets -arch=$(CUDA_DOCKER_ARCH)
@@ -687,9 +716,15 @@ define NVCC_COMPILE
     $(NVCC) -I. -Icommon -D_XOPEN_SOURCE=600 -D_GNU_SOURCE -DNDEBUG -DGGML_USE_CUDA -I/usr/local/cuda/include -I/opt/cuda/include -I/usr/local/cuda/targets/aarch64-linux/include -std=c++11 -O3 $(NVCCFLAGS) $(CPPFLAGS) -Xcompiler "$(CUDA_CXXFLAGS)" -c $< -o $@
 endef # NVCC_COMPILE
 else
+ifdef GGML_MUSA
+define NVCC_COMPILE
+    $(NVCC) $(NVCCFLAGS) $(CPPFLAGS) -c $< -o $@
+endef # NVCC_COMPILE
+else
 define NVCC_COMPILE
     $(NVCC) $(NVCCFLAGS) $(CPPFLAGS) -Xcompiler "$(CUDA_CXXFLAGS)" -c $< -o $@
 endef # NVCC_COMPILE
+endif # GGML_MUSA
 endif # JETSON_EOL_MODULE_DETECT

 ggml/src/ggml-cuda/%.o: \
@@ -944,6 +979,7 @@ $(info I CXX: $(shell $(CXX) --version | head -n 1))
 ifdef GGML_CUDA
     $(info I NVCC: $(shell $(NVCC) --version | tail -n 1))
     CUDA_VERSION := $(shell $(NVCC) --version | grep -oP 'release (\K[0-9]+\.[0-9])')
+ifndef GGML_MUSA
 ifeq ($(shell awk -v "v=$(CUDA_VERSION)" 'BEGIN { print (v < 11.7) }'),1)

 ifndef CUDA_DOCKER_ARCH
@@ -953,6 +989,7 @@ endif # CUDA_POWER_ARCH
 endif # CUDA_DOCKER_ARCH

 endif # eq ($(shell echo "$(CUDA_VERSION) < 11.7" | bc),1)
+endif # GGML_MUSA
 endif # GGML_CUDA
 $(info )

README.md

Lines changed: 1 addition & 0 deletions
@@ -409,6 +409,7 @@ Please refer to [Build llama.cpp locally](./docs/build.md)
 | [BLAS](./docs/build.md#blas-build) | All |
 | [BLIS](./docs/backend/BLIS.md) | All |
 | [SYCL](./docs/backend/SYCL.md) | Intel and Nvidia GPU |
+| [MUSA](./docs/build.md#musa) | Moore Threads GPU |
 | [CUDA](./docs/build.md#cuda) | Nvidia GPU |
 | [hipBLAS](./docs/build.md#hipblas) | AMD GPU |
 | [Vulkan](./docs/build.md#vulkan) | GPU |

common/common.cpp

Lines changed: 5 additions & 0 deletions
@@ -1324,6 +1324,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         else { invalid_param = true; }
         return true;
     }
+    if (arg == "--no-warmup") {
+        params.warmup = false;
+        return true;
+    }
 #ifndef LOG_DISABLE_LOGS
     // Parse args for logging parameters
     if (log_param_single_parse(argv[i])) {
@@ -1446,6 +1450,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
     options.push_back({ "main infill", " --in-prefix-bos", "prefix BOS to user inputs, preceding the `--in-prefix` string" });
     options.push_back({ "main infill", " --in-prefix STRING", "string to prefix user inputs with (default: empty)" });
     options.push_back({ "main infill", " --in-suffix STRING", "string to suffix after user inputs with (default: empty)" });
+    options.push_back({ "main", " --no-warmup", "skip warming up the model with an empty run" });
     options.push_back({ "server infill",
         " --spm-infill", "use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this. (default: %s)", params.spm_infill ? "enabled" : "disabled" });

convert_hf_to_gguf.py

Lines changed: 28 additions & 0 deletions
@@ -1570,6 +1570,34 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         return [(self.map_tensor_name(name), data_torch)]

     def prepare_tensors(self):
+        if rope_scaling := self.find_hparam(["rope_scaling"], optional=True):
+            if rope_scaling.get("rope_type", '').lower() == "llama3":
+                base = self.hparams.get("rope_theta", 10000.0)
+                dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
+                freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
+
+                factor = rope_scaling.get("factor", 8.0)
+                low_freq_factor = rope_scaling.get("low_freq_factor", 1.0)
+                high_freq_factor = rope_scaling.get("high_freq_factor", 4.0)
+                old_context_len = self.hparams.get("original_max_position_embeddings", 8192)
+
+                low_freq_wavelen = old_context_len / low_freq_factor
+                high_freq_wavelen = old_context_len / high_freq_factor
+                assert low_freq_wavelen != high_freq_wavelen
+
+                rope_factors = []
+                for freq in freqs:
+                    wavelen = 2 * math.pi / freq
+                    if wavelen < high_freq_wavelen:
+                        rope_factors.append(1)
+                    elif wavelen > low_freq_wavelen:
+                        rope_factors.append(factor)
+                    else:
+                        smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
+                        rope_factors.append(1 / ((1 - smooth) / factor + smooth))
+
+                self.gguf_writer.add_tensor(self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), np.array(rope_factors, dtype=np.float32))
+
         super().prepare_tensors()

         if self._experts is not None:
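
The added `prepare_tensors()` logic precomputes per-dimension correction factors for Llama-3-style rope scaling and stores them in the `ROPE_FREQS` tensor. For readers who want to see the values, the following is a minimal standalone C++ sketch of the same math; the hyperparameters here (hidden size 4096, 32 heads, theta 500000, factor 8.0) are illustrative and not read from any model config:

```cpp
// Sketch of the Llama-3 rope-scaling factor computation shown in the hunk above.
// All hyperparameter values below are illustrative assumptions.
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const double kPi              = 3.141592653589793;
    const double base             = 500000.0;  // rope_theta (illustrative)
    const int    dim              = 4096 / 32; // hidden_size / num_attention_heads
    const double factor           = 8.0;
    const double low_freq_factor  = 1.0;
    const double high_freq_factor = 4.0;
    const double old_context_len  = 8192.0;    // original_max_position_embeddings

    const double low_freq_wavelen  = old_context_len / low_freq_factor;
    const double high_freq_wavelen = old_context_len / high_freq_factor;

    std::vector<float> rope_factors;
    for (int i = 0; i < dim; i += 2) {
        const double freq    = 1.0 / std::pow(base, (double) i / dim);
        const double wavelen = 2.0 * kPi / freq;
        if (wavelen < high_freq_wavelen) {
            rope_factors.push_back(1.0f);            // high-frequency dims stay unscaled
        } else if (wavelen > low_freq_wavelen) {
            rope_factors.push_back((float) factor);  // low-frequency dims get the full factor
        } else {
            // smooth interpolation between the two regimes
            const double smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor);
            rope_factors.push_back((float) (1.0 / ((1.0 - smooth) / factor + smooth)));
        }
    }

    // These dim/2 values correspond to what the converter writes into the rope-factors tensor.
    for (const float f : rope_factors) {
        printf("%.4f\n", f);
    }
    return 0;
}
```

The converter itself performs this in Python with torch, as the hunk shows; the sketch only illustrates the arithmetic.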

docs/build.md

Lines changed: 13 additions & 0 deletions
@@ -192,6 +192,19 @@ The environment variable [`CUDA_VISIBLE_DEVICES`](https://docs.nvidia.com/cuda/c
 | GGML_CUDA_PEER_MAX_BATCH_SIZE | Positive integer | 128 | Maximum batch size for which to enable peer access between multiple GPUs. Peer access requires either Linux or NVLink. When using NVLink enabling peer access for larger batch sizes is potentially beneficial. |
 | GGML_CUDA_FA_ALL_QUANTS | Boolean | false | Compile support for all KV cache quantization type (combinations) for the FlashAttention CUDA kernels. More fine-grained control over KV cache size but compilation takes much longer. |

+### MUSA
+
+- Using `make`:
+  ```bash
+  make GGML_MUSA=1
+  ```
+- Using `CMake`:
+
+  ```bash
+  cmake -B build -DGGML_MUSA=ON
+  cmake --build build --config Release
+  ```
+
 ### hipBLAS

 This provides BLAS acceleration on HIP-supported AMD GPUs.

examples/eval-callback/eval-callback.cpp

Lines changed: 1 addition & 1 deletion
@@ -62,7 +62,7 @@ static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne
         } else if (type == GGML_TYPE_I8) {
             v = (float) *(int8_t *) &data[i];
         } else {
-            GGML_ASSERT(false);
+            GGML_ABORT("fatal error");
         }
         printf("%12.4f", v);
         sum += v;

examples/imatrix/imatrix.cpp

Lines changed: 2 additions & 2 deletions
@@ -127,7 +127,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
         }
         else if (e.values.size() != (size_t)src1->ne[0]*n_as) {
             fprintf(stderr, "Oops: inconsistent size for %s (%d vs %d)\n", wname.c_str(), (int)e.values.size(), (int)src1->ne[0]*n_as);
-            exit(1); //GGML_ASSERT(false);
+            exit(1); //GGML_ABORT("fatal error");
         }
         if (m_params.verbosity > 1) {
             printf("%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_call, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[2], (int)src1->type);
@@ -176,7 +176,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
         }
         else if (e.values.size() != (size_t)src1->ne[0]) {
             fprintf(stderr, "Oops: inconsistent size for %s (%d vs %d)\n", wname.c_str(), (int)e.values.size(), (int)src1->ne[0]);
-            exit(1); //GGML_ASSERT(false);
+            exit(1); //GGML_ABORT("fatal error");
         }
         ++e.ncall;
         if (m_params.verbosity > 1) {

examples/llama-bench/llama-bench.cpp

Lines changed: 3 additions & 3 deletions
@@ -150,7 +150,7 @@ static const char * output_format_str(output_formats format) {
         case JSON: return "json";
         case MARKDOWN: return "md";
         case SQL: return "sql";
-        default: GGML_ASSERT(!"invalid output format");
+        default: GGML_ABORT("invalid output format");
     }
 }

@@ -176,7 +176,7 @@ static const char * split_mode_str(llama_split_mode mode) {
         case LLAMA_SPLIT_MODE_NONE: return "none";
         case LLAMA_SPLIT_MODE_LAYER: return "layer";
         case LLAMA_SPLIT_MODE_ROW: return "row";
-        default: GGML_ASSERT(!"invalid split mode");
+        default: GGML_ABORT("invalid split mode");
     }
 }

@@ -1326,7 +1326,7 @@ static std::unique_ptr<printer> create_printer(output_formats format) {
         case SQL:
             return std::unique_ptr<printer>(new sql_printer());
     }
-    GGML_ASSERT(false);
+    GGML_ABORT("fatal error");
 }

 int main(int argc, char ** argv) {

examples/llava/clip.cpp

Lines changed: 1 addition & 1 deletion
@@ -869,7 +869,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
             embeddings = peg_0;
         }
         else {
-            GGML_ASSERT(false);
+            GGML_ABORT("fatal error");
         }
     }

examples/save-load-state/save-load-state.cpp

Lines changed: 13 additions & 7 deletions
@@ -47,7 +47,7 @@ int main(int argc, char ** argv) {
     // save state (rng, logits, embedding and kv_cache) to file
     {
         std::vector<uint8_t> state_mem(llama_state_get_size(ctx));
-        const size_t written = llama_state_get_data(ctx, state_mem.data());
+        const size_t written = llama_state_get_data(ctx, state_mem.data(), state_mem.size());

         FILE *fp_write = fopen("dump_state.bin", "wb");
         fwrite(state_mem.data(), 1, written, fp_write);
@@ -99,13 +99,16 @@ int main(int argc, char ** argv) {

     // load state (rng, logits, embedding and kv_cache) from file
     {
-        std::vector<uint8_t> state_mem(llama_state_get_size(ctx2));
+        std::vector<uint8_t> state_mem;

         FILE * fp_read = fopen("dump_state.bin", "rb");
+        fseek(fp_read, 0, SEEK_END);
+        state_mem.resize(ftell(fp_read));
+        fseek(fp_read, 0, SEEK_SET);
         const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
         fclose(fp_read);

-        if (read != llama_state_set_data(ctx2, state_mem.data())) {
+        if (read != llama_state_set_data(ctx2, state_mem.data(), state_mem.size())) {
             fprintf(stderr, "\n%s : failed to read state\n", __func__);
             llama_free(ctx2);
             llama_free_model(model);
@@ -159,13 +162,16 @@ int main(int argc, char ** argv) {

     // load state (rng, logits, embedding and kv_cache) from file
     {
-        std::vector<uint8_t> state_mem(llama_state_get_size(ctx3));
+        std::vector<uint8_t> state_mem;

         FILE * fp_read = fopen("dump_state.bin", "rb");
+        fseek(fp_read, 0, SEEK_END);
+        state_mem.resize(ftell(fp_read));
+        fseek(fp_read, 0, SEEK_SET);
         const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
         fclose(fp_read);

-        if (read != llama_state_set_data(ctx3, state_mem.data())) {
+        if (read != llama_state_set_data(ctx3, state_mem.data(), state_mem.size())) {
             fprintf(stderr, "\n%s : failed to read state\n", __func__);
             llama_free(ctx3);
             llama_free_model(model);
@@ -182,7 +188,7 @@ int main(int argc, char ** argv) {
     {
         // save kv of seq 0
         std::vector<uint8_t> seq_store(llama_state_seq_get_size(ctx3, 0));
-        const size_t ncopy = llama_state_seq_get_data(ctx3, seq_store.data(), 0);
+        const size_t ncopy = llama_state_seq_get_data(ctx3, seq_store.data(), seq_store.size(), 0);
         if (ncopy != seq_store.size()) {
             fprintf(stderr, "\n%s : seq copy data length %zd does not match expected length %zd\n", __func__, ncopy, seq_store.size());
             llama_free(ctx3);
@@ -196,7 +202,7 @@ int main(int argc, char ** argv) {
         fprintf(stderr, "%s : kv cache cleared\n", __func__);

         // restore kv into seq 1
-        const size_t nset = llama_state_seq_set_data(ctx3, seq_store.data(), 1);
+        const size_t nset = llama_state_seq_set_data(ctx3, seq_store.data(), seq_store.size(), 1);
         if (nset != seq_store.size()) {
             fprintf(stderr, "\n%s : seq set data length %zd does not match expected length %zd\n", __func__, nset, seq_store.size());
             llama_free(ctx3);
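
These hunks track an API change: `llama_state_get_data`, `llama_state_set_data`, and their sequence variants now take an explicit buffer size. A minimal sketch of the updated round trip, assuming an already-initialized `llama_context * ctx` created elsewhere (error handling omitted):

```cpp
// Sketch only: `ctx` is assumed to come from llama_new_context_with_model().
#include "llama.h"

#include <cstdint>
#include <vector>

static bool roundtrip_state(llama_context * ctx) {
    // Save: query the required size, then pass the buffer capacity explicitly.
    std::vector<uint8_t> state(llama_state_get_size(ctx));
    const size_t written = llama_state_get_data(ctx, state.data(), state.size());

    // Restore: the size argument bounds how much the library may read from the buffer.
    const size_t applied = llama_state_set_data(ctx, state.data(), state.size());

    return written != 0 && applied == written;
}
```

The per-sequence calls (`llama_state_seq_get_data` / `llama_state_seq_set_data`) gain the same size parameter, as the last two hunks show.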

examples/tokenize/tokenize.cpp

Lines changed: 1 addition & 1 deletion
@@ -163,7 +163,7 @@ static void write_utf8_cstr_to_stdout(const char * str, bool & invalid_utf8) {
             printf(">");
             return;
         }
-        GGML_ASSERT(false && "MultiByteToWideChar() failed in an unexpected way.");
+        GGML_ABORT("MultiByteToWideChar() failed in an unexpected way.");
     }

     LPWSTR wstr = (LPWSTR) calloc(length_needed+1, sizeof(*wstr));

ggml/CMakeLists.txt

Lines changed: 9 additions & 2 deletions
@@ -50,9 +50,15 @@ else()
     set(GGML_BLAS_VENDOR_DEFAULT "Generic")
 endif()

+if (CMAKE_CROSSCOMPILING)
+    set(GGML_NATIVE_DEFAULT OFF)
+else()
+    set(GGML_NATIVE_DEFAULT ON)
+endif()
+
 # general
 option(GGML_STATIC "ggml: static link libraries" OFF)
-option(GGML_NATIVE "ggml: enable -march=native flag" ON)
+option(GGML_NATIVE "ggml: enable -march=native flag" ${GGML_NATIVE_DEFAULT})
 option(GGML_LTO "ggml: enable link time optimization" OFF)
 option(GGML_CCACHE "ggml: use ccache if available" ON)

@@ -70,7 +76,7 @@ option(GGML_SANITIZE_ADDRESS "ggml: enable address sanitizer" OFF)
 option(GGML_SANITIZE_UNDEFINED "ggml: enable undefined sanitizer" OFF)

 # instruction set specific
-if (GGML_NATIVE)
+if (GGML_NATIVE OR NOT GGML_NATIVE_DEFAULT)
     set(INS_ENB OFF)
 else()
     set(INS_ENB ON)
@@ -107,6 +113,7 @@ set(GGML_BLAS_VENDOR ${GGML_BLAS_VENDOR_DEFAULT} CACHE STRING
 option(GGML_LLAMAFILE "ggml: use LLAMAFILE" OFF)

 option(GGML_CUDA "ggml: use CUDA" OFF)
+option(GGML_MUSA "ggml: use MUSA" OFF)
 option(GGML_CUDA_FORCE_DMMV "ggml: use dmmv instead of mmvq CUDA kernels" OFF)
 option(GGML_CUDA_FORCE_MMQ "ggml: use mmq kernels instead of cuBLAS" OFF)
 option(GGML_CUDA_FORCE_CUBLAS "ggml: always use cuBLAS instead of mmq kernels" OFF)

ggml/include/ggml-cuda.h

Lines changed: 3 additions & 0 deletions
@@ -6,6 +6,9 @@
 #ifdef GGML_USE_HIPBLAS
 #define GGML_CUDA_NAME "ROCm"
 #define GGML_CUBLAS_NAME "hipBLAS"
+#elif defined(GGML_USE_MUSA)
+#define GGML_CUDA_NAME "MUSA"
+#define GGML_CUBLAS_NAME "muBLAS"
 #else
 #define GGML_CUDA_NAME "CUDA"
 #define GGML_CUBLAS_NAME "cuBLAS"
