From b97bc3966e852adb626c90be64fd48282800f504 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Sun, 21 Apr 2024 13:50:41 +0200 Subject: [PATCH] llama : support Llama 3 HF conversion (#6745) * Support Llama 3 conversion The tokenizer is BPE. * style * Accept suggestion Co-authored-by: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com> * llama : add llama_token_is_eog() ggml-ci * llama : auto-detect more EOT tokens when missing in KV data * convert : replacing EOS token is a hack * llama : fix codegemma EOT token + add TODOs * llama : fix model type string for 8B model --------- Co-authored-by: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com> Co-authored-by: Georgi Gerganov --- convert-hf-to-gguf.py | 47 ++++++++---- convert.py | 9 ++- examples/batched.swift/Sources/main.swift | 2 +- examples/batched/batched.cpp | 4 +- examples/beam-search/beam-search.cpp | 2 +- examples/infill/infill.cpp | 10 +-- .../app/src/main/cpp/llama-android.cpp | 2 +- .../llama.cpp.swift/LibLlama.swift | 2 +- examples/llava/llava-cli.cpp | 2 +- examples/lookahead/lookahead.cpp | 2 +- examples/lookup/lookup.cpp | 2 +- examples/main/main.cpp | 8 +-- examples/parallel/parallel.cpp | 2 +- examples/passkey/passkey.cpp | 4 +- examples/server/server.cpp | 2 +- examples/server/utils.hpp | 4 -- examples/simple/simple.cpp | 4 +- examples/speculative/speculative.cpp | 2 +- llama.cpp | 72 ++++++++++++++----- llama.h | 5 +- 20 files changed, 123 insertions(+), 64 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 358dba8ed..4fd916cba 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -1301,15 +1301,23 @@ class LlamaModel(Model): try: self. _set_vocab_sentencepiece() except FileNotFoundError: - self._set_vocab_llama_hf() + try: + self._set_vocab_llama_hf() + except (FileNotFoundError, TypeError): + # Llama 3 + self._set_vocab_gpt2() - special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False, - special_token_types = ['prefix', 'suffix', 'middle', 'eot']) - special_vocab._set_special_token("prefix", 32007) - special_vocab._set_special_token("suffix", 32008) - special_vocab._set_special_token("middle", 32009) - special_vocab._set_special_token("eot", 32010) - special_vocab.add_to_gguf(self.gguf_writer) + # Apply to CodeLlama only (and ignore for Llama 3 with a vocab size of 128256) + if self.hparams.get("vocab_size", 32000) == 32016: + special_vocab = gguf.SpecialVocab( + self.dir_model, load_merges=False, + special_token_types = ['prefix', 'suffix', 'middle', 'eot'] + ) + special_vocab._set_special_token("prefix", 32007) + special_vocab._set_special_token("suffix", 32008) + special_vocab._set_special_token("middle", 32009) + special_vocab._set_special_token("eot", 32010) + special_vocab.add_to_gguf(self.gguf_writer) def set_gguf_parameters(self): super().set_gguf_parameters() @@ -2194,6 +2202,8 @@ class InternLM2Model(Model): old_eos = special_vocab.special_token_ids["eos"] if "chat" in os.path.basename(self.dir_model.absolute()): # For the chat model, we replace the eos with '<|im_end|>'. 
+ # TODO: this is a hack, should be fixed + # https://github.com/ggerganov/llama.cpp/pull/6745#issuecomment-2067687048 special_vocab.special_token_ids["eos"] = self._try_get_sft_eos(tokenizer) print(f"Replace eos:{old_eos} with a special token:{special_vocab.special_token_ids['eos']} \ in chat mode so that the conversation can end normally.") @@ -2429,12 +2439,15 @@ class GemmaModel(Model): def set_vocab(self): self._set_vocab_sentencepiece() + + # TODO: these special tokens should be exported only for the CodeGemma family special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False, - special_token_types = ['prefix', 'suffix', 'middle', 'eot']) + special_token_types = ['prefix', 'suffix', 'middle', 'fsep', 'eot']) special_vocab._set_special_token("prefix", 67) special_vocab._set_special_token("suffix", 69) special_vocab._set_special_token("middle", 68) - special_vocab._set_special_token("eot", 70) + special_vocab._set_special_token("fsep", 70) + special_vocab._set_special_token("eot", 107) special_vocab.add_to_gguf(self.gguf_writer) def set_gguf_parameters(self): @@ -2523,28 +2536,34 @@ class MambaModel(Model): field = neox_reader.get_field(gguf.Keys.Tokenizer.MODEL) self.gguf_writer.add_tokenizer_model(bytes(field.parts[-1])) + field = neox_reader.get_field(gguf.Keys.Tokenizer.LIST) self.gguf_writer.add_token_list([bytes(field.parts[i]) for i in field.data][:vocab_size]) + field = neox_reader.get_field(gguf.Keys.Tokenizer.TOKEN_TYPE) self.gguf_writer.add_token_types([field.parts[i].tolist()[0] for i in field.data][:vocab_size]) + field = neox_reader.get_field(gguf.Keys.Tokenizer.MERGES) self.gguf_writer.add_token_merges([bytes(field.parts[i]) for i in field.data]) + field = neox_reader.get_field(gguf.Keys.Tokenizer.BOS_ID) self.gguf_writer.add_bos_token_id(field.parts[-1].tolist()[0]) + field = neox_reader.get_field(gguf.Keys.Tokenizer.EOS_ID) self.gguf_writer.add_eos_token_id(field.parts[-1].tolist()[0]) + field = neox_reader.get_field(gguf.Keys.Tokenizer.UNK_ID) self.gguf_writer.add_unk_token_id(field.parts[-1].tolist()[0]) def set_gguf_parameters(self): - d_model = self.find_hparam(["hidden_size", "d_model"]) - d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4 + d_model = self.find_hparam(["hidden_size", "d_model"]) + d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4 d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model - d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 16 + d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 16 # ceiling division # ref: https://stackoverflow.com/a/17511341/22827863 # ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58 - dt_rank = self.find_hparam(["time_step_rank", "dt_rank"], optional=True) or -(d_model // -16) + dt_rank = self.find_hparam(["time_step_rank", "dt_rank"], optional=True) or -(d_model // -16) rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5 # Fail early for models which don't have a block expansion factor of 2 diff --git a/convert.py b/convert.py index 24df0a4d8..1c700cf6a 100755 --- a/convert.py +++ b/convert.py @@ -525,7 +525,14 @@ class LlamaHfVocab(Vocab): # pre-check so we know if we need transformers tokenizer_model: dict[str, Any] = tokenizer_json['model'] - if ( + is_llama3 = ( + tokenizer_model['type'] == 'BPE' and tokenizer_model.get('ignore_merges', False) + and not 
tokenizer_model.get('byte_fallback', True) + ) + if is_llama3: + raise TypeError('Llama 3 must be converted with BpeVocab') + + if not is_llama3 and ( tokenizer_model['type'] != 'BPE' or not tokenizer_model.get('byte_fallback', False) or tokenizer_json['decoder']['type'] != 'Sequence' ): diff --git a/examples/batched.swift/Sources/main.swift b/examples/batched.swift/Sources/main.swift index d75c503d5..5764acb6d 100644 --- a/examples/batched.swift/Sources/main.swift +++ b/examples/batched.swift/Sources/main.swift @@ -153,7 +153,7 @@ while n_cur <= n_len { // const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p); // is it an end of stream? -> mark the stream as finished - if new_token_id == llama_token_eos(model) || n_cur == n_len { + if llama_token_is_eog(model, new_token_id) || n_cur == n_len { i_batch[i] = -1 // print("") if n_parallel > 1 { diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index 7aaf63ceb..be30d20bf 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -191,8 +191,8 @@ int main(int argc, char ** argv) { //const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p); - // is it an end of stream? -> mark the stream as finished - if (new_token_id == llama_token_eos(model) || n_cur == n_len) { + // is it an end of generation? -> mark the stream as finished + if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) { i_batch[i] = -1; LOG_TEE("\n"); if (n_parallel > 1) { diff --git a/examples/beam-search/beam-search.cpp b/examples/beam-search/beam-search.cpp index 866c6d7a6..3d34378a5 100644 --- a/examples/beam-search/beam-search.cpp +++ b/examples/beam-search/beam-search.cpp @@ -47,7 +47,7 @@ struct beam_search_callback_data { // In this case, end-of-beam (eob) is equivalent to end-of-sentence (eos) but this need not always be the same. // For example, eob can be flagged due to maximum token length, stop words, etc. static bool is_at_eob(const beam_search_callback_data & callback_data, const llama_token * tokens, size_t n_tokens) { - return n_tokens && tokens[n_tokens-1] == llama_token_eos(llama_get_model(callback_data.ctx)); + return n_tokens && llama_token_is_eog(llama_get_model(callback_data.ctx), tokens[n_tokens-1]); } // Function matching type llama_beam_search_callback_fn_t. 
diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp index c69dcd06e..afac145f6 100644 --- a/examples/infill/infill.cpp +++ b/examples/infill/infill.cpp @@ -586,7 +586,7 @@ int main(int argc, char ** argv) { // deal with eot token in infill mode if ((llama_sampling_last(ctx_sampling) == llama_token_eot(model) || is_interacting) && params.interactive){ - if(is_interacting && !params.interactive_first) { + if (is_interacting && !params.interactive_first) { // print an eot token printf("%s", llama_token_to_piece(ctx, llama_token_eot(model)).c_str()); } @@ -651,8 +651,8 @@ int main(int argc, char ** argv) { // LOG_TEE("took new input\n"); is_interacting = false; } - // deal with end of text token in interactive mode - else if (llama_sampling_last(ctx_sampling) == llama_token_eos(model)) { + // deal with end of generation tokens in interactive mode + else if (llama_token_is_eog(model, llama_sampling_last(ctx_sampling))) { LOG("found EOS token\n"); if (params.interactive) { @@ -731,8 +731,8 @@ int main(int argc, char ** argv) { } } - // end of text token - if (!embd.empty() && embd.back() == llama_token_eos(model) && !params.interactive) { + // end of generation + if (!embd.empty() && llama_token_is_eog(model, embd.back()) && !params.interactive) { break; } diff --git a/examples/llama.android/app/src/main/cpp/llama-android.cpp b/examples/llama.android/app/src/main/cpp/llama-android.cpp index ce8ab3b70..4af9de303 100644 --- a/examples/llama.android/app/src/main/cpp/llama-android.cpp +++ b/examples/llama.android/app/src/main/cpp/llama-android.cpp @@ -408,7 +408,7 @@ Java_com_example_llama_Llm_completion_1loop( const auto new_token_id = llama_sample_token_greedy(context, &candidates_p); const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value); - if (new_token_id == llama_token_eos(model) || n_cur == n_len) { + if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) { return env->NewStringUTF(""); } diff --git a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift index c249291ae..70c43a385 100644 --- a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift +++ b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift @@ -158,7 +158,7 @@ actor LlamaContext { new_token_id = llama_sample_token_greedy(context, &candidates_p) } - if new_token_id == llama_token_eos(model) || n_cur == n_len { + if llama_token_is_eog(model, new_token_id) || n_cur == n_len { print("\n") let new_token_str = String(cString: temporary_invalid_cchars + [0]) temporary_invalid_cchars.removeAll() diff --git a/examples/llava/llava-cli.cpp b/examples/llava/llava-cli.cpp index 75948806e..50dac4cae 100644 --- a/examples/llava/llava-cli.cpp +++ b/examples/llava/llava-cli.cpp @@ -45,7 +45,7 @@ static const char * sample(struct llama_sampling_context * ctx_sampling, const llama_token id = llama_sampling_sample(ctx_sampling, ctx_llama, NULL); llama_sampling_accept(ctx_sampling, ctx_llama, id, true); static std::string ret; - if (id == llama_token_eos(llama_get_model(ctx_llama))) { + if (llama_token_is_eog(llama_get_model(ctx_llama), id)) { ret = "</s>"; } else { ret = llama_token_to_piece(ctx_llama, id); diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp index 5af6a8ab6..9c3540b20 100644 --- a/examples/lookahead/lookahead.cpp +++ b/examples/lookahead/lookahead.cpp @@ -299,7 +299,7 @@ int main(int argc, char ** argv) { } fflush(stdout); - if (id == llama_token_eos(model)) { + if (llama_token_is_eog(model, id)) { has_eos =
true; } diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp index 65ed408a2..9526e898f 100644 --- a/examples/lookup/lookup.cpp +++ b/examples/lookup/lookup.cpp @@ -141,7 +141,7 @@ int main(int argc, char ** argv){ printf("%s", token_str.c_str()); } - if (id == llama_token_eos(model)) { + if (llama_token_is_eog(model, id)) { has_eos = true; } diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 249fc2bb6..1180734b9 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -795,8 +795,8 @@ int main(int argc, char ** argv) { } } - // deal with end of text token in interactive mode - if (llama_sampling_last(ctx_sampling) == llama_token_eos(model)) { + // deal with end of generation tokens in interactive mode + if (llama_token_is_eog(model, llama_sampling_last(ctx_sampling))) { LOG("found EOS token\n"); if (params.interactive) { @@ -920,8 +920,8 @@ int main(int argc, char ** argv) { } } - // end of text token - if (!embd.empty() && embd.back() == llama_token_eos(model) && !(params.instruct || params.interactive || params.chatml)) { + // end of generation + if (!embd.empty() && llama_token_is_eog(model, embd.back()) && !(params.instruct || params.interactive || params.chatml)) { LOG_TEE(" [end of text]\n"); break; } diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index f66c91013..7c5595d6e 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -359,7 +359,7 @@ int main(int argc, char ** argv) { // client.id, client.seq_id, id, client.n_decoded, client.i_batch, token_str.c_str()); if (client.n_decoded > 2 && - (id == llama_token_eos(model) || + (llama_token_is_eog(model, id) || (params.n_predict > 0 && client.n_decoded + client.n_prompt >= params.n_predict) || client.response.find("User:") != std::string::npos || client.response.find('\n') != std::string::npos)) { diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp index 2cbc9e1fa..f2ef9ca10 100644 --- a/examples/passkey/passkey.cpp +++ b/examples/passkey/passkey.cpp @@ -252,8 +252,8 @@ int main(int argc, char ** argv) { // sample the most likely token const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p); - // is it an end of stream? - if (new_token_id == llama_token_eos(model) || n_cur == n_len) { + // is it an end of generation? 
+ if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) { LOG_TEE("\n"); break; diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 634e653ad..25bc29639 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1201,7 +1201,7 @@ struct server_context { }); } - if (result.tok == llama_token_eos(model)) { + if (llama_token_is_eog(model, result.tok)) { slot.stopped_eos = true; slot.has_next_token = false; diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index a8d43ac63..1a2212502 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -381,10 +381,6 @@ static json oaicompat_completion_params_parse( } else { llama_params["stop"] = json_value(body, "stop", json::array()); } - // Some chat templates don't use EOS token to stop generation - // We must add their end sequences to list of stop words - llama_params["stop"].push_back("<|im_end|>"); // chatml - llama_params["stop"].push_back("<end_of_turn>"); // gemma // Handle "response_format" field if (body.contains("response_format")) { diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp index 39e2d8ea4..b0f8e0fdc 100644 --- a/examples/simple/simple.cpp +++ b/examples/simple/simple.cpp @@ -133,8 +133,8 @@ int main(int argc, char ** argv) { // sample the most likely token const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p); - // is it an end of stream? - if (new_token_id == llama_token_eos(model) || n_cur == n_len) { + // is it an end of generation? + if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) { LOG_TEE("\n"); break; diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 6a7367b0c..12e46fbc9 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -360,7 +360,7 @@ int main(int argc, char ** argv) { } } - if (token_id == llama_token_eos(model_tgt)) { + if (llama_token_is_eog(model_tgt, token_id)) { has_eos = true; } ++n_predict; diff --git a/llama.cpp b/llama.cpp index fa7c022f2..8ca9650de 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2120,7 +2120,7 @@ struct llama_vocab { id special_prefix_id = -1; id special_suffix_id = -1; id special_middle_id = -1; - id special_eot_id = -1; + id special_eot_id = -1; // TODO: move above after "eos_id", and here add "file separator" token bool add_space_prefix = true; @@ -3770,7 +3770,7 @@ static void llm_load_hparams( switch (hparams.n_layer) { case 22: model.type = e_model::MODEL_1B; break; case 26: model.type = e_model::MODEL_3B; break; - case 32: model.type = e_model::MODEL_7B; break; + case 32: model.type = hparams.n_head == hparams.n_head_kv ?
e_model::MODEL_7B : e_model::MODEL_8B; break; // LLaMa 8B v3 uses GQA case 40: model.type = e_model::MODEL_13B; break; case 48: model.type = e_model::MODEL_34B; break; case 60: model.type = e_model::MODEL_30B; break; @@ -4179,7 +4179,10 @@ static void llm_load_vocab( vocab.special_prefix_id = 67; vocab.special_suffix_id = 69; vocab.special_middle_id = 68; - vocab.special_eot_id = 70; + // TODO: this is not EOT, it is "file separator" token, needs fix + // https://huggingface.co/google/codegemma-7b-it/blob/9b1d9231388358c04d90bd003458f5070d97db44/tokenizer_config.json#L565-L572 + //vocab.special_eot_id = 70; + vocab.special_eot_id = 107; } } @@ -4308,6 +4311,7 @@ static void llm_load_vocab( { LLM_KV_TOKENIZER_MIDDLE_ID, vocab.special_middle_id }, { LLM_KV_TOKENIZER_EOT_ID, vocab.special_eot_id }, }; + for (const auto & it : special_token_types) { const std::string & key = kv(std::get<0>(it)); int32_t & id = std::get<1>(it); @@ -4322,7 +4326,6 @@ static void llm_load_vocab( } else { id = new_id; } - } // Handle add_bos_token and add_eos_token @@ -4336,6 +4339,27 @@ static void llm_load_vocab( vocab.special_add_eos = int(temp); } } + + // find EOT token: "<|eot_id|>", "<|im_emd|>", "<end_of_turn>", etc. + // + // TODO: convert scripts should provide this token through the KV metadata LLAMA_KV_TOKENIZER_EOT_ID + // for now, we apply this workaround to find the EOT token based on its text + if (vocab.special_eot_id == -1) { + for (const auto & t : vocab.token_to_id) { + if ( + // TODO: gemma "<end_of_turn>" is exported as a normal token, so the following check does not work + // need to fix convert script + //vocab.id_to_token[t.second].type == LLAMA_TOKEN_TYPE_CONTROL && + (t.first == "<|eot_id|>" || + t.first == "<|im_emd|>" || + t.first == "<end_of_turn>" + ) + ) { + vocab.special_eot_id = t.second; + break; + } + } + } } // build special tokens cache @@ -4498,14 +4522,19 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { LLAMA_LOG_INFO("%s: general.name = %s\n", __func__, model.name.c_str()); // special tokens - if (vocab.special_bos_id != -1) { LLAMA_LOG_INFO( "%s: BOS token = %d '%s'\n", __func__, vocab.special_bos_id, vocab.id_to_token[vocab.special_bos_id].text.c_str() ); } - if (vocab.special_eos_id != -1) { LLAMA_LOG_INFO( "%s: EOS token = %d '%s'\n", __func__, vocab.special_eos_id, vocab.id_to_token[vocab.special_eos_id].text.c_str() ); } - if (vocab.special_unk_id != -1) { LLAMA_LOG_INFO( "%s: UNK token = %d '%s'\n", __func__, vocab.special_unk_id, vocab.id_to_token[vocab.special_unk_id].text.c_str() ); } - if (vocab.special_sep_id != -1) { LLAMA_LOG_INFO( "%s: SEP token = %d '%s'\n", __func__, vocab.special_sep_id, vocab.id_to_token[vocab.special_sep_id].text.c_str() ); } - if (vocab.special_pad_id != -1) { LLAMA_LOG_INFO( "%s: PAD token = %d '%s'\n", __func__, vocab.special_pad_id, vocab.id_to_token[vocab.special_pad_id].text.c_str() ); } - if (vocab.special_cls_id != -1) { LLAMA_LOG_INFO( "%s: CLS token = %d '%s'\n", __func__, vocab.special_cls_id, vocab.id_to_token[vocab.special_cls_id].text.c_str() ); } - if (vocab.special_mask_id != -1) { LLAMA_LOG_INFO( "%s: MASK token = %d '%s'\n", __func__, vocab.special_mask_id, vocab.id_to_token[vocab.special_mask_id].text.c_str() ); } - if (vocab.linefeed_id != -1) { LLAMA_LOG_INFO( "%s: LF token = %d '%s'\n", __func__, vocab.linefeed_id, vocab.id_to_token[vocab.linefeed_id].text.c_str() ); } + if (vocab.special_bos_id != -1) { LLAMA_LOG_INFO( "%s: BOS token = %d '%s'\n", __func__, vocab.special_bos_id,
vocab.id_to_token[vocab.special_bos_id].text.c_str() ); } + if (vocab.special_eos_id != -1) { LLAMA_LOG_INFO( "%s: EOS token = %d '%s'\n", __func__, vocab.special_eos_id, vocab.id_to_token[vocab.special_eos_id].text.c_str() ); } + if (vocab.special_unk_id != -1) { LLAMA_LOG_INFO( "%s: UNK token = %d '%s'\n", __func__, vocab.special_unk_id, vocab.id_to_token[vocab.special_unk_id].text.c_str() ); } + if (vocab.special_sep_id != -1) { LLAMA_LOG_INFO( "%s: SEP token = %d '%s'\n", __func__, vocab.special_sep_id, vocab.id_to_token[vocab.special_sep_id].text.c_str() ); } + if (vocab.special_pad_id != -1) { LLAMA_LOG_INFO( "%s: PAD token = %d '%s'\n", __func__, vocab.special_pad_id, vocab.id_to_token[vocab.special_pad_id].text.c_str() ); } + if (vocab.special_cls_id != -1) { LLAMA_LOG_INFO( "%s: CLS token = %d '%s'\n", __func__, vocab.special_cls_id, vocab.id_to_token[vocab.special_cls_id].text.c_str() ); } + if (vocab.special_mask_id != -1) { LLAMA_LOG_INFO( "%s: MASK token = %d '%s'\n", __func__, vocab.special_mask_id, vocab.id_to_token[vocab.special_mask_id].text.c_str() ); } + + if (vocab.linefeed_id != -1) { LLAMA_LOG_INFO( "%s: LF token = %d '%s'\n", __func__, vocab.linefeed_id, vocab.id_to_token[vocab.linefeed_id].text.c_str() ); } + if (vocab.special_prefix_id != -1) { LLAMA_LOG_INFO( "%s: PRE token = %d '%s'\n", __func__, vocab.special_prefix_id, vocab.id_to_token[vocab.special_prefix_id].text.c_str() ); } + if (vocab.special_suffix_id != -1) { LLAMA_LOG_INFO( "%s: SUF token = %d '%s'\n", __func__, vocab.special_suffix_id, vocab.id_to_token[vocab.special_suffix_id].text.c_str() ); } + if (vocab.special_middle_id != -1) { LLAMA_LOG_INFO( "%s: MID token = %d '%s'\n", __func__, vocab.special_middle_id, vocab.id_to_token[vocab.special_middle_id].text.c_str() ); } + if (vocab.special_eot_id != -1) { LLAMA_LOG_INFO( "%s: EOT token = %d '%s'\n", __func__, vocab.special_eot_id, vocab.id_to_token[vocab.special_eot_id].text.c_str() ); } } // Returns false if cancelled by progress_callback @@ -13268,16 +13297,14 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c GGML_ASSERT(ctx); const int64_t t_start_sample_us = ggml_time_us(); - bool allow_eos = false; + bool allow_eog = false; for (const auto & stack : grammar->stacks) { if (stack.empty()) { - allow_eos = true; + allow_eog = true; break; } } - const llama_token eos = llama_token_eos(&ctx->model); - std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded; candidates_decoded.reserve(candidates->size); std::vector<llama_grammar_candidate> candidates_grammar; @@ -13286,8 +13313,8 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c for (size_t i = 0; i < candidates->size; ++i) { const llama_token id = candidates->data[i].id; const std::string piece = llama_token_to_piece(ctx, id); - if (id == eos) { - if (!allow_eos) { + if (llama_token_is_eog(&ctx->model, id)) { + if (!allow_eog) { candidates->data[i].logit = -INFINITY; } } else if (piece.empty() || piece[0] == 0) { @@ -13476,7 +13503,7 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) { const int64_t t_start_sample_us = ggml_time_us(); - if (token == llama_token_eos(&ctx->model)) { + if (llama_token_is_eog(&ctx->model, token)) { for (const auto & stack : grammar->stacks) { if (stack.empty()) { return; @@ -16880,6 +16907,13 @@ llama_token_type llama_token_get_type(const struct llama_model * model, llama_to return
model->vocab.id_to_token[token].type; } +bool llama_token_is_eog(const struct llama_model * model, llama_token token) { + return token != -1 && ( + token == llama_token_eos(model) || + token == llama_token_eot(model) + ); +} + llama_token llama_token_bos(const struct llama_model * model) { return model->vocab.special_bos_id; } diff --git a/llama.h b/llama.h index b5da686f7..5bed97ad1 100644 --- a/llama.h +++ b/llama.h @@ -783,6 +783,9 @@ extern "C" { LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token); + // Check if the token is supposed to end generation (end-of-generation, eg. EOS, EOT, etc.) + LLAMA_API bool llama_token_is_eog(const struct llama_model * model, llama_token token); + // Special tokens LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence @@ -796,7 +799,7 @@ extern "C" { // Returns -1 if unknown, 1 for true or 0 for false. LLAMA_API int32_t llama_add_eos_token(const struct llama_model * model); - // codellama infill tokens + // Codellama infill tokens LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix LLAMA_API llama_token llama_token_middle(const struct llama_model * model); // Beginning of infill middle LLAMA_API llama_token llama_token_suffix(const struct llama_model * model); // Beginning of infill suffix
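
Usage sketch (not part of the patch): the snippet below shows how a generation loop can rely on the new llama_token_is_eog() API instead of comparing against llama_token_eos() directly, mirroring the loops updated above in examples/simple and examples/passkey. The helper name generate_greedy and its setup assumptions (a loaded model and context, a prompt that has already been decoded, and n_past counting the tokens evaluated so far) are illustrative only; the llama.h and common.h calls are assumed to match the API at the time of this patch.

#include "common.h"
#include "llama.h"

#include <string>
#include <vector>

// Greedy decoding that stops on any end-of-generation token (EOS, EOT, ...)
// via llama_token_is_eog(), rather than on llama_token_eos() alone.
static std::string generate_greedy(llama_context * ctx, const llama_model * model, int n_past, int n_predict) {
    std::string out;

    for (int i = 0; i < n_predict; ++i) {
        const int n_vocab = llama_n_vocab(model);
        float   * logits  = llama_get_logits(ctx);

        // build the candidate list for the last decoded position
        std::vector<llama_token_data> candidates;
        candidates.reserve(n_vocab);
        for (llama_token id = 0; id < n_vocab; ++id) {
            candidates.push_back({ id, logits[id], 0.0f });
        }
        llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };

        // sample the most likely token
        llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);

        // EOS, EOT (e.g. <|eot_id|> for Llama 3 Instruct) and any other EOG token stop generation here
        if (llama_token_is_eog(model, new_token_id)) {
            break;
        }

        out += llama_token_to_piece(ctx, new_token_id);

        // feed the sampled token back for the next step
        if (llama_decode(ctx, llama_batch_get_one(&new_token_id, 1, n_past++, 0)) != 0) {
            break;
        }
    }

    return out;
}

With end-of-generation resolved from the vocabulary itself, hard-coded stop strings such as the "<|im_end|>" and "<end_of_turn>" entries removed from examples/server/utils.hpp above are no longer needed for models that declare their EOT token.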