Skip to content

Commit b6ce743

Browse files
authored
llama-graph : fix text position for mrope (#13159)
* llama-graph : fix text position for mrope * fix typo * explicitly set 4th dim in the loop
1 parent 5f5e39e commit b6ce743

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

src/llama-graph.cpp

+7-4
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,16 @@ void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
5555
if (ubatch->pos && pos) {
5656
const int64_t n_tokens = ubatch->n_tokens;
5757

58-
if (ubatch->token && n_pos_per_embd > 1) {
58+
if (ubatch->token && n_pos_per_embd == 4) {
5959
// in case we're using M-RoPE with text tokens, convert the 1D positions to 4D
60-
// the other dimensions are all 0, they are unused for text tokens
61-
std::vector<llama_pos> pos_data(n_tokens*n_pos_per_embd, 0);
60+
// the 3 first dims are the same, and 4th dim is all 0
61+
std::vector<llama_pos> pos_data(n_tokens*n_pos_per_embd);
6262
// copy the first dimension
6363
for (int i = 0; i < n_tokens; ++i) {
64-
pos_data[i] = ubatch->pos[i];
64+
pos_data[ i] = ubatch->pos[i];
65+
pos_data[ n_tokens + i] = ubatch->pos[i];
66+
pos_data[2 * n_tokens + i] = ubatch->pos[i];
67+
pos_data[3 * n_tokens + i] = 0; // 4th dim is 0
6568
}
6669
ggml_backend_tensor_set(pos, pos_data.data(), 0, pos_data.size()*ggml_element_size(pos));
6770
} else {

0 commit comments

Comments
 (0)