diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 5cb3e63fb..4d995ef78 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -10,7 +10,7 @@ import re import sys from enum import IntEnum from pathlib import Path -from typing import TYPE_CHECKING, Any, ContextManager, Iterator, cast, Optional +from typing import TYPE_CHECKING, Any, ContextManager, Iterator, cast import numpy as np import torch @@ -487,7 +487,8 @@ class MPTModel(Model): # map tensor names if "scales" in name: new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias", ".scales")) - new_name = new_name.replace("scales", "act.scales") + if new_name is not None: + new_name = new_name.replace("scales", "act.scales") else: new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias")) if new_name is None: @@ -904,7 +905,7 @@ class QwenModel(Model): return ''.join([byte_encoder[ord(char)] for char in b.decode('latin-1')]) @staticmethod - def bpe(mergeable_ranks: dict[bytes, int], token: bytes, max_rank: Optional[int] = None) -> list[bytes]: + def bpe(mergeable_ranks: dict[bytes, int], token: bytes, max_rank: int | None = None) -> list[bytes]: parts = [bytes([b]) for b in token] while True: min_idx = None @@ -1285,7 +1286,7 @@ def main() -> None: if args.awq_path: sys.path.insert(1, str(Path(__file__).parent / 'awq-py')) - from awq.apply_awq import add_scale_weights + from awq.apply_awq import add_scale_weights # type: ignore[import-not-found] tmp_model_path = args.model / "weighted_model" dir_model = tmp_model_path if tmp_model_path.is_dir(): diff --git a/convert-llama-ggml-to-gguf.py b/convert-llama-ggml-to-gguf.py index e359330af..b33108062 100755 --- a/convert-llama-ggml-to-gguf.py +++ b/convert-llama-ggml-to-gguf.py @@ -2,6 +2,7 @@ from __future__ import annotations import argparse +import os import struct import sys from enum import IntEnum @@ -9,7 +10,6 @@ from pathlib import Path import numpy as np -import os if 'NO_LOCAL_GGUF' not in os.environ: sys.path.insert(1, str(Path(__file__).parent / 'gguf-py')) import gguf @@ -371,15 +371,11 @@ def handle_metadata(cfg, hp): params = convert.Params.loadOriginalParamsJson(fakemodel, orig_config_path) else: raise ValueError('Unable to load metadata') - vocab = convert.load_vocab( - cfg.vocab_dir if cfg.vocab_dir is not None else cfg.model_metadata_dir, - cfg.vocabtype) - # FIXME: Respect cfg.vocab_dir? - svocab = gguf.SpecialVocab(cfg.model_metadata_dir, - load_merges = cfg.vocabtype == 'bpe', - n_vocab = vocab.vocab_size) + vocab_path = Path(cfg.vocab_dir if cfg.vocab_dir is not None else cfg.model_metadata_dir) + vocab_factory = convert.VocabFactory(vocab_path) + vocab, special_vocab = vocab_factory.load_vocab(cfg.vocabtype, cfg.model_metadata_dir) convert.check_vocab_size(params, vocab) - return (params, vocab, svocab) + return params, vocab, special_vocab def handle_args(): diff --git a/convert-lora-to-ggml.py b/convert-lora-to-ggml.py index 35ce152f4..4904bf128 100755 --- a/convert-lora-to-ggml.py +++ b/convert-lora-to-ggml.py @@ -5,17 +5,16 @@ import json import os import struct import sys +from pathlib import Path from typing import Any, BinaryIO, Sequence import numpy as np import torch -from pathlib import Path if 'NO_LOCAL_GGUF' not in os.environ: sys.path.insert(1, str(Path(__file__).parent / 'gguf-py' / 'gguf')) import gguf - NUMPY_TYPE_TO_FTYPE: dict[str, int] = {"float32": 0, "float16": 1} diff --git a/convert-persimmon-to-gguf.py b/convert-persimmon-to-gguf.py index 1ba5864dc..d2be805d1 100755 --- a/convert-persimmon-to-gguf.py +++ b/convert-persimmon-to-gguf.py @@ -1,11 +1,13 @@ #!/usr/bin/env python3 -import torch -import os -from pprint import pprint -import sys import argparse +import os +import sys from pathlib import Path +from pprint import pprint + +import torch from sentencepiece import SentencePieceProcessor + if 'NO_LOCAL_GGUF' not in os.environ: sys.path.insert(1, str(Path(__file__).parent / 'gguf-py')) import gguf @@ -69,7 +71,7 @@ def main(): persimmon_model = torch.load(args.ckpt_path) hparams = persimmon_model['args'] pprint(hparams) - tensors = {} + tensors: dict[str, torch.Tensor] = {} _flatten_dict(persimmon_model['model'], tensors, None) arch = gguf.MODEL_ARCH.PERSIMMON diff --git a/convert.py b/convert.py index 980e6fc72..06768033d 100755 --- a/convert.py +++ b/convert.py @@ -17,58 +17,28 @@ import signal import struct import sys import time -import warnings import zipfile from abc import ABCMeta, abstractmethod -from argparse import ArgumentParser from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor from dataclasses import dataclass from pathlib import Path -from typing import ( - IO, - TYPE_CHECKING, - Any, - Callable, - Iterable, - Literal, - Optional, - Tuple, - TypeVar, -) +from typing import IO, TYPE_CHECKING, Any, Callable, Iterable, Literal, TypeVar import numpy as np from sentencepiece import SentencePieceProcessor -try: - from transformers import AutoTokenizer -except ModuleNotFoundError as e: - warnings.warn(f"Could not import AutoTokenizer from transformers: {e}") +if 'NO_LOCAL_GGUF' not in os.environ: + sys.path.insert(1, str(Path(__file__).parent / 'gguf-py')) +import gguf -# If NO_LOCAL_GGUF is not set, try to import gguf from the local gguf-py directory -if "NO_LOCAL_GGUF" not in os.environ: - # Use absolute path to the gguf-py directory - gguf_py_dir = str(Path(__file__).resolve().parent / "gguf-py") - print(gguf_py_dir) # NOTE: Remove this once path is verified after changes are completed - if gguf_py_dir not in sys.path: - sys.path.insert(1, gguf_py_dir) +if TYPE_CHECKING: + from typing import TypeAlias -# Import gguf module -try: - import gguf -except ModuleNotFoundError as e: - print(f"Could not import gguf: {e}") - sys.exit(1) - -if TYPE_CHECKING: # NOTE: This isn't necessary. - from typing import TypeAlias # This can technically be omitted. - -if hasattr(faulthandler, "register") and hasattr(signal, "SIGUSR1"): +if hasattr(faulthandler, 'register') and hasattr(signal, 'SIGUSR1'): faulthandler.register(signal.SIGUSR1) -# NOTE: n-dimensional arrays should be directly referenced -NDArray: TypeAlias = "np.ndarray[Any, Any]" +NDArray: TypeAlias = 'np.ndarray[Any, Any]' -# Why is this here? LLAMA and GPT are technically the only compatible ARCHs. ARCH = gguf.MODEL_ARCH.LLAMA DEFAULT_CONCURRENCY = 8 @@ -78,7 +48,6 @@ DEFAULT_CONCURRENCY = 8 # -# TODO: Clean up and refactor data types @dataclass(frozen=True) class DataType: name: str @@ -183,85 +152,65 @@ GGML_FILE_TYPE_TO_DATA_TYPE: dict[GGMLFileType, DataType] = { @dataclass class Params: - n_vocab: int - n_embd: int - n_layer: int - n_ctx: int - n_ff: int - n_head: int - n_head_kv: int - f_norm_eps: Optional[float] = None - n_experts: Optional[int] = None - n_experts_used: Optional[int] = None + n_vocab: int + n_embd: int + n_layer: int + n_ctx: int + n_ff: int + n_head: int + n_head_kv: int + n_experts: int | None = None + n_experts_used: int | None = None + f_norm_eps: float | None = None - rope_scaling_type: Optional[gguf.RopeScalingType] = None - f_rope_freq_base: Optional[float] = None - f_rope_scale: Optional[float] = None - n_orig_ctx: Optional[int] = None - rope_finetuned: Optional[bool] = None + rope_scaling_type: gguf.RopeScalingType | None = None + f_rope_freq_base: float | None = None + f_rope_scale: float | None = None + n_orig_ctx: int | None = None + rope_finetuned: bool | None = None - ftype: Optional[GGMLFileType] = None + ftype: GGMLFileType | None = None # path to the directory containing the model files - path_model: Optional[Path] = None + path_model: Path | None = None @staticmethod - def guessed(model: LazyModel) -> "Params": + def guessed(model: LazyModel) -> Params: # try transformer naming first - n_vocab, n_embd = ( - model["model.embed_tokens.weight"].shape - if "model.embed_tokens.weight" in model - else model["tok_embeddings.weight"].shape - ) + n_vocab, n_embd = model["model.embed_tokens.weight"].shape if "model.embed_tokens.weight" in model else model["tok_embeddings.weight"].shape # try transformer naming first if "model.layers.0.self_attn.q_proj.weight" in model: - n_layer = next( - i - for i in itertools.count() - if f"model.layers.{i}.self_attn.q_proj.weight" not in model - ) - elif ( - "model.layers.0.self_attn.W_pack.weight" in model - ): # next: try baichuan naming - n_layer = next( - i - for i in itertools.count() - if f"model.layers.{i}.self_attn.W_pack.weight" not in model - ) + n_layer = next(i for i in itertools.count() if f"model.layers.{i}.self_attn.q_proj.weight" not in model) + elif "model.layers.0.self_attn.W_pack.weight" in model: # next: try baichuan naming + n_layer = next(i for i in itertools.count() if f"model.layers.{i}.self_attn.W_pack.weight" not in model) else: - n_layer = next( - i - for i in itertools.count() - if f"layers.{i}.attention.wq.weight" not in model - ) + n_layer = next(i for i in itertools.count() if f"layers.{i}.attention.wq.weight" not in model) if n_layer < 1: - raise Exception( - "failed to guess 'n_layer'. This model is unknown or unsupported.\n" - "Suggestion: provide 'config.json' of the model in the same directory containing model files." - ) + raise Exception("failed to guess 'n_layer'. This model is unknown or unsupported.\n" + "Suggestion: provide 'config.json' of the model in the same directory containing model files.") - n_head = n_embd // 128 # guessed - n_mult = 256 # guessed + n_head = n_embd // 128 # guessed + n_mult = 256 # guessed # TODO: verify this n_ff = int(2 * (4 * n_embd) / 3) n_ff = n_mult * ((n_ff + n_mult - 1) // n_mult) return Params( - n_vocab=n_vocab, - n_embd=n_embd, - n_layer=n_layer, - n_ctx=-1, - n_ff=n_ff, - n_head=n_head, - n_head_kv=n_head, - f_norm_eps=1e-5, + n_vocab = n_vocab, + n_embd = n_embd, + n_layer = n_layer, + n_ctx = -1, + n_ff = n_ff, + n_head = n_head, + n_head_kv = n_head, + f_norm_eps = 1e-5, ) @staticmethod - def load_transformers_config(model: LazyModel, config_path: Path) -> "Params": + def loadHFTransformerJson(model: LazyModel, config_path: Path) -> Params: config = json.load(open(config_path)) rope_scaling_type = f_rope_scale = n_orig_ctx = rope_finetuned = None @@ -274,22 +223,20 @@ class Params: rope_scaling_type = gguf.RopeScalingType.LINEAR elif typ == "yarn": rope_scaling_type = gguf.RopeScalingType.YARN - n_orig_ctx = rope_scaling["original_max_position_embeddings"] - rope_finetuned = rope_scaling["finetuned"] + n_orig_ctx = rope_scaling['original_max_position_embeddings'] + rope_finetuned = rope_scaling['finetuned'] else: - raise NotImplementedError(f"Unknown rope scaling type: {typ}") + raise NotImplementedError(f'Unknown rope scaling type: {typ}') if "max_sequence_length" in config: n_ctx = config["max_sequence_length"] elif "max_position_embeddings" in config: n_ctx = config["max_position_embeddings"] else: - raise Exception( - "failed to guess 'n_ctx'. This model is unknown or unsupported.\n" - "Suggestion: provide 'config.json' of the model in the same directory containing model files." - ) + raise Exception("failed to guess 'n_ctx'. This model is unknown or unsupported.\n" + "Suggestion: provide 'config.json' of the model in the same directory containing model files.") - n_experts = None + n_experts = None n_experts_used = None if "num_local_experts" in config: @@ -297,30 +244,30 @@ class Params: n_experts_used = config["num_experts_per_tok"] return Params( - n_vocab=config["vocab_size"], - n_embd=config["hidden_size"], - n_layer=config["num_hidden_layers"], - n_ctx=n_ctx, - n_ff=config["intermediate_size"], - n_head=(n_head := config["num_attention_heads"]), - n_head_kv=config.get("num_key_value_heads", n_head), - n_experts=n_experts, - n_experts_used=n_experts_used, - f_norm_eps=config["rms_norm_eps"], - f_rope_freq_base=config.get("rope_theta"), - rope_scaling_type=rope_scaling_type, - f_rope_scale=f_rope_scale, - n_orig_ctx=n_orig_ctx, - rope_finetuned=rope_finetuned, + n_vocab = config["vocab_size"], + n_embd = config["hidden_size"], + n_layer = config["num_hidden_layers"], + n_ctx = n_ctx, + n_ff = config["intermediate_size"], + n_head = (n_head := config["num_attention_heads"]), + n_head_kv = config.get("num_key_value_heads", n_head), + n_experts = n_experts, + n_experts_used = n_experts_used, + f_norm_eps = config["rms_norm_eps"], + f_rope_freq_base = config.get("rope_theta"), + rope_scaling_type = rope_scaling_type, + f_rope_scale = f_rope_scale, + n_orig_ctx = n_orig_ctx, + rope_finetuned = rope_finetuned, ) # LLaMA v2 70B params.json # {"dim": 8192, "multiple_of": 4096, "ffn_dim_multiplier": 1.3, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": -1} @staticmethod - def load_torch_params(model: LazyModel, config_path: Path) -> "Params": + def loadOriginalParamsJson(model: LazyModel, config_path: Path) -> Params: config = json.load(open(config_path)) - n_experts = None + n_experts = None n_experts_used = None f_rope_freq_base = None @@ -343,50 +290,50 @@ class Params: if config.get("moe"): n_ff = model["layers.0.feed_forward.experts.0.w1.weight"].shape[0] - n_experts = config["moe"]["num_experts"] + n_experts = config["moe"]["num_experts"] n_experts_used = config["moe"]["num_experts_per_tok"] f_rope_freq_base = 1e6 return Params( - n_vocab=model["tok_embeddings.weight"].shape[0], - n_embd=config["dim"], - n_layer=config["n_layers"], - n_ctx=n_ctx, - n_ff=n_ff, - n_head=(n_head := config["n_heads"]), - n_head_kv=config.get("n_kv_heads", n_head), - n_experts=n_experts, - n_experts_used=n_experts_used, - f_norm_eps=config["norm_eps"], - f_rope_freq_base=config.get("rope_theta", f_rope_freq_base), + n_vocab = model["tok_embeddings.weight"].shape[0], + n_embd = config["dim"], + n_layer = config["n_layers"], + n_ctx = n_ctx, + n_ff = n_ff, + n_head = (n_head := config["n_heads"]), + n_head_kv = config.get("n_kv_heads", n_head), + n_experts = n_experts, + n_experts_used = n_experts_used, + f_norm_eps = config["norm_eps"], + f_rope_freq_base = config.get("rope_theta", f_rope_freq_base), ) @staticmethod - def load(model_plus: ModelPlus) -> "Params": - hf_config_path = model_plus.paths[0].parent / "config.json" + def load(model_plus: ModelPlus) -> Params: + hf_config_path = model_plus.paths[0].parent / "config.json" orig_config_path = model_plus.paths[0].parent / "params.json" if hf_config_path.exists(): - params = Params.load_transformers_config(model_plus.model, hf_config_path) + params = Params.loadHFTransformerJson(model_plus.model, hf_config_path) elif orig_config_path.exists(): - params = Params.load_torch_params(model_plus.model, orig_config_path) - elif model_plus.format != "none": + params = Params.loadOriginalParamsJson(model_plus.model, orig_config_path) + elif model_plus.format != 'none': params = Params.guessed(model_plus.model) else: - raise ValueError("Cannot guess params when model format is none") + raise ValueError('Cannot guess params when model format is none') params.path_model = model_plus.paths[0].parent return params -class BpeVocab: # GPT - def __init__( - self, fname_tokenizer: Path, fname_added_tokens: Optional[Path] - ) -> None: - self.bpe_tokenizer = json.loads( - open(str(fname_tokenizer), encoding="utf-8").read() - ) +# +# vocab +# + +class BpeVocab: + def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None) -> None: + self.bpe_tokenizer = json.loads(open(str(fname_tokenizer), encoding="utf-8").read()) self.vocab = self.bpe_tokenizer["model"]["vocab"] added_tokens: dict[str, int] if fname_added_tokens is not None: @@ -394,34 +341,31 @@ class BpeVocab: # GPT added_tokens = json.load(open(fname_added_tokens, encoding="utf-8")) else: # Fall back to trying to find the added tokens in tokenizer.json - tokenizer_json_file = fname_tokenizer.parent / "tokenizer.json" + tokenizer_json_file = fname_tokenizer.parent / 'tokenizer.json' if not tokenizer_json_file.is_file(): added_tokens = {} else: tokenizer_json = json.load(open(tokenizer_json_file, encoding="utf-8")) added_tokens = dict( - (item["content"], item["id"]) - for item in tokenizer_json.get("added_tokens", []) + (item['content'], item['id']) + for item in tokenizer_json.get('added_tokens', []) # Added tokens here can be duplicates of the main vocabulary. - if item["content"] not in self.bpe_tokenizer - ) + if item['content'] not in self.bpe_tokenizer) vocab_size: int = len(self.vocab) - expected_ids = list(range(vocab_size, vocab_size + len(added_tokens))) - actual_ids = sorted(added_tokens.values()) + expected_ids = list(range(vocab_size, vocab_size + len(added_tokens))) + actual_ids = sorted(added_tokens.values()) if expected_ids != actual_ids: expected_end_id = vocab_size + len(actual_ids) - 1 - raise Exception( - f"Expected the {len(actual_ids)} added token ID(s) to be sequential in the range {vocab_size} - {expected_end_id}; got {actual_ids}" - ) + raise Exception(f"Expected the {len(actual_ids)} added token ID(s) to be sequential in the range {vocab_size} - {expected_end_id}; got {actual_ids}") items = sorted(added_tokens.items(), key=lambda text_idx: text_idx[1]) - self.added_tokens_dict = added_tokens - self.added_tokens_list = [text for (text, idx) in items] + self.added_tokens_dict = added_tokens + self.added_tokens_list = [text for (text, idx) in items] self.vocab_size_base: int = vocab_size - self.vocab_size: int = self.vocab_size_base + len(self.added_tokens_list) - self.fname_tokenizer = fname_tokenizer - self.fname_added_tokens = fname_added_tokens + self.vocab_size: int = self.vocab_size_base + len(self.added_tokens_list) + self.fname_tokenizer = fname_tokenizer + self.fname_added_tokens = fname_added_tokens def bpe_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: reverse_vocab = {id: encoded_tok for encoded_tok, id in self.vocab.items()} @@ -442,10 +386,8 @@ class BpeVocab: # GPT return f"" -class SentencePieceVocab: # LlaMa - def __init__( - self, fname_tokenizer: Path, fname_added_tokens: Optional[Path] - ) -> None: +class SentencePieceVocab: + def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None) -> None: self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer)) added_tokens: dict[str, int] if fname_added_tokens is not None: @@ -455,23 +397,19 @@ class SentencePieceVocab: # LlaMa vocab_size: int = self.sentencepiece_tokenizer.vocab_size() - new_tokens = { - id: piece for piece, id in added_tokens.items() if id >= vocab_size - } + new_tokens = {id: piece for piece, id in added_tokens.items() if id >= vocab_size} expected_new_ids = list(range(vocab_size, vocab_size + len(new_tokens))) - actual_new_ids = sorted(new_tokens.keys()) + actual_new_ids = sorted(new_tokens.keys()) if expected_new_ids != actual_new_ids: - raise ValueError( - f"Expected new token IDs {expected_new_ids} to be sequential; got {actual_new_ids}" - ) + raise ValueError(f"Expected new token IDs {expected_new_ids} to be sequential; got {actual_new_ids}") # Token pieces that were added to the base vocabulary. self.added_tokens_dict = added_tokens - self.added_tokens_list = [new_tokens[id] for id in actual_new_ids] - self.vocab_size_base = vocab_size - self.vocab_size = self.vocab_size_base + len(self.added_tokens_list) - self.fname_tokenizer = fname_tokenizer + self.added_tokens_list = [new_tokens[id] for id in actual_new_ids] + self.vocab_size_base = vocab_size + self.vocab_size = self.vocab_size_base + len(self.added_tokens_list) + self.fname_tokenizer = fname_tokenizer self.fname_added_tokens = fname_added_tokens def sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: @@ -512,11 +450,15 @@ class SentencePieceVocab: # LlaMa class HfVocab: - def __init__( - self, - fname_tokenizer: Path, - fname_added_tokens: Optional[Path] = None, - ) -> None: + def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None = None) -> None: + try: + from transformers import AutoTokenizer + except ImportError as e: + raise ImportError( + "To use HfVocab, please install the `transformers` package. " + "You can install it with `pip install transformers`." + ) from e + print("fname_tokenizer:", fname_tokenizer) # Allow the tokenizer to default to slow or fast versions. # Explicitly set tokenizer to use local paths. @@ -529,7 +471,7 @@ class HfVocab: # Initialize lists and dictionaries for added tokens self.added_tokens_list = [] self.added_tokens_dict = dict() - self.added_tokens_ids = set() + self.added_tokens_ids = set() # Process added tokens for tok, tokidx in sorted( @@ -550,12 +492,12 @@ class HfVocab: # Set vocabulary sizes self.vocab_size_base = self.tokenizer.vocab_size - self.vocab_size = self.vocab_size_base + len(self.added_tokens_list) + self.vocab_size = self.vocab_size_base + len(self.added_tokens_list) - self.fname_tokenizer = fname_tokenizer + self.fname_tokenizer = fname_tokenizer self.fname_added_tokens = fname_added_tokens - def hf_tokens(self) -> Iterable[Tuple[bytes, float, gguf.TokenType]]: + def hf_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: reverse_vocab = { id: encoded_tok for encoded_tok, id in self.tokenizer.get_vocab().items() } @@ -573,11 +515,9 @@ class HfVocab: token_id, self.special_ids # Reuse already stored special IDs ) - def get_token_type(self, token_id: int, special_ids: set) -> gguf.TokenType: + def get_token_type(self, token_id: int, special_ids: set[int]) -> gguf.TokenType: # Determine token type based on whether it's a special token - return ( - gguf.TokenType.CONTROL if token_id in special_ids else gguf.TokenType.NORMAL - ) + return gguf.TokenType.CONTROL if token_id in special_ids else gguf.TokenType.NORMAL def get_token_score(self, token_id: int) -> float: # Placeholder for actual logic to determine the token's score @@ -589,7 +529,6 @@ class HfVocab: if text in self.specials: toktype = self.get_token_type(self.specials[text], self.special_ids) score = self.get_token_score(self.specials[text]) - else: toktype = gguf.TokenType.USER_DEFINED score = -1000.0 @@ -783,7 +722,7 @@ def merge_multifile_models(models_plus: list[ModelPlus]) -> ModelPlus: else: model = merge_sharded([mp.model for mp in models_plus]) - return ModelPlus(model, paths, format, vocab) + return ModelPlus(model, paths, format, vocab) # pytype: disable=wrong-arg-types def permute_lazy(lazy_tensor: LazyTensor, n_head: int, n_head_kv: int) -> LazyTensor: @@ -871,17 +810,13 @@ class LazyUnpickler(pickle.Unpickler): CLASSES: dict[tuple[str, str], Any] = { # getattr used here as a workaround for mypy not being smart enough to determine # the staticmethods have a __func__ attribute. - ("torch._tensor", "_rebuild_from_type_v2"): getattr( - rebuild_from_type_v2, "__func__" - ), - ("torch._utils", "_rebuild_tensor_v2"): getattr( - lazy_rebuild_tensor_v2, "__func__" - ), - ("torch", "BFloat16Storage"): LazyStorageKind(DT_BF16), - ("torch", "HalfStorage"): LazyStorageKind(DT_F16), - ("torch", "FloatStorage"): LazyStorageKind(DT_F32), - ("torch", "IntStorage"): LazyStorageKind(DT_I32), - ("torch", "Tensor"): LazyTensor, + ('torch._tensor', '_rebuild_from_type_v2'): getattr(rebuild_from_type_v2, '__func__'), + ('torch._utils', '_rebuild_tensor_v2'): getattr(lazy_rebuild_tensor_v2, '__func__'), + ('torch', 'BFloat16Storage'): LazyStorageKind(DT_BF16), + ('torch', 'HalfStorage'): LazyStorageKind(DT_F16), + ('torch', 'FloatStorage'): LazyStorageKind(DT_F32), + ('torch', 'IntStorage'): LazyStorageKind(DT_I32), + ('torch', 'Tensor'): LazyTensor, } def find_class(self, module: str, name: str) -> Any: @@ -968,7 +903,7 @@ def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], conc executor_class = ProcessPoolExecutor else: executor_class = ThreadPoolExecutor - with executor_class(max_workers = max_workers) as executor: + with executor_class(max_workers=max_workers) as executor: futures: list[concurrent.futures.Future[Out]] = [] done = False for _ in range(concurrency): @@ -1022,12 +957,8 @@ def check_vocab_size(params: Params, vocab: Vocab, pad_vocab: bool = False) -> N class OutputFile: - def __init__( - self, fname_out: Path, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE - ) -> None: - self.gguf = gguf.GGUFWriter( - fname_out, gguf.MODEL_ARCH_NAMES[ARCH], endianess=endianess - ) + def __init__(self, fname_out: Path, endianess:gguf.GGUFEndian = gguf.GGUFEndian.LITTLE) -> None: + self.gguf = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH], endianess=endianess) def add_meta_arch(self, params: Params) -> None: name = "LLaMA" @@ -1036,21 +967,16 @@ class OutputFile: if params.n_ctx == 4096: name = "LLaMA v2" elif params.path_model is not None: - name = str(params.path_model.parent).split("/")[-1] + name = str(params.path_model.parent).split('/')[-1] - self.gguf.add_name(name) - self.gguf.add_context_length(params.n_ctx) - self.gguf.add_embedding_length(params.n_embd) - self.gguf.add_block_count(params.n_layer) - self.gguf.add_feed_forward_length(params.n_ff) + self.gguf.add_name (name) + self.gguf.add_context_length (params.n_ctx) + self.gguf.add_embedding_length (params.n_embd) + self.gguf.add_block_count (params.n_layer) + self.gguf.add_feed_forward_length (params.n_ff) self.gguf.add_rope_dimension_count(params.n_embd // params.n_head) - self.gguf.add_head_count(params.n_head) - self.gguf.add_head_count_kv(params.n_head_kv) - - if params.f_norm_eps is None: - raise ValueError("f_norm_eps is None") - - self.gguf.add_layer_norm_rms_eps(params.f_norm_eps) + self.gguf.add_head_count (params.n_head) + self.gguf.add_head_count_kv (params.n_head_kv) if params.n_experts: self.gguf.add_expert_count(params.n_experts) @@ -1058,6 +984,11 @@ class OutputFile: if params.n_experts_used: self.gguf.add_expert_used_count(params.n_experts_used) + if params.f_norm_eps: + self.gguf.add_layer_norm_rms_eps(params.f_norm_eps) + else: + raise ValueError('f_norm_eps is None') + if params.f_rope_freq_base is not None: self.gguf.add_rope_freq_base(params.f_rope_freq_base) @@ -1089,7 +1020,7 @@ class OutputFile: return tokenizer_model - def extract_vocabulary_from_model(self, vocab: Vocab) -> Tuple[list, list, list]: + def extract_vocabulary_from_model(self, vocab: Vocab) -> tuple[list[bytes], list[float], list[gguf.TokenType]]: tokens = [] scores = [] toktypes = [] @@ -1124,14 +1055,10 @@ class OutputFile: def add_tensor_info(self, name: str, tensor: LazyTensor) -> None: n_elements = int(np.prod(tensor.shape)) - raw_dtype = getattr(tensor.data_type, "ggml_type", None) - data_type = ( - getattr(tensor.data_type, "quantized_type", None) or tensor.data_type.dtype - ) + raw_dtype = getattr(tensor.data_type, 'ggml_type', None) + data_type = getattr(tensor.data_type, 'quantized_type', None) or tensor.data_type.dtype data_nbytes = tensor.data_type.elements_to_bytes(n_elements) - self.gguf.add_tensor_info( - name, tensor.shape, data_type, data_nbytes, raw_dtype=raw_dtype - ) + self.gguf.add_tensor_info(name, tensor.shape, data_type, data_nbytes, raw_dtype=raw_dtype) def write_meta(self) -> None: self.gguf.write_header_to_file() @@ -1145,14 +1072,10 @@ class OutputFile: @staticmethod def write_vocab_only( - fname_out: Path, - params: Params, - vocab: Vocab, - svocab: gguf.SpecialVocab, - endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, - pad_vocab: bool = False, + fname_out: Path, params: Params, vocab: Vocab, svocab: gguf.SpecialVocab, + endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, pad_vocab: bool = False, ) -> None: - check_vocab_size(params, vocab, pad_vocab=pad_vocab) + check_vocab_size(params, vocab, pad_vocab = pad_vocab) of = OutputFile(fname_out, endianess=endianess) @@ -1180,14 +1103,8 @@ class OutputFile: @staticmethod def write_all( - fname_out: Path, - ftype: GGMLFileType, - params: Params, - model: LazyModel, - vocab: Vocab, - svocab: gguf.SpecialVocab, - concurrency: int = DEFAULT_CONCURRENCY, - endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, + fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, svocab: gguf.SpecialVocab, + concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, pad_vocab: bool = False, ) -> None: check_vocab_size(params, vocab, pad_vocab=pad_vocab) @@ -1207,26 +1124,19 @@ class OutputFile: of.write_tensor_info() # tensor data - ndarrays_inner = bounded_parallel_map( - OutputFile.do_item, model.items(), concurrency=concurrency - ) + ndarrays_inner = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency = concurrency) if ftype == GGMLFileType.MostlyQ8_0: ndarrays = bounded_parallel_map( - OutputFile.maybe_do_quantize, - ndarrays_inner, - concurrency=concurrency, - max_workers=concurrency, + OutputFile.maybe_do_quantize, ndarrays_inner, concurrency=concurrency, max_workers=concurrency, use_processpool_executor=True, ) else: ndarrays = map(OutputFile.maybe_do_quantize, ndarrays_inner) start = time.time() - for i, ((name, lazy_tensor), ndarray) in enumerate( - zip(model.items(), ndarrays) - ): + for i, ((name, lazy_tensor), ndarray) in enumerate(zip(model.items(), ndarrays)): elapsed = time.time() - start - size = " x ".join(f"{dim:6d}" for dim in lazy_tensor.shape) + size = ' x '.join(f"{dim:6d}" for dim in lazy_tensor.shape) padi = len(str(len(model))) print( f"[{i+1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type.name:4} | T+{int(elapsed):4}" @@ -1363,7 +1273,7 @@ def load_some_model(path: Path) -> ModelPlus: class VocabFactory: def __init__(self, path: Path): self.path = path - self.files = { + self.files: dict[str, Path | None] = { "tokenizer.model": None, "vocab.json": None, "tokenizer.json": None, @@ -1380,24 +1290,18 @@ class VocabFactory: self.files[file] = parent_file_path print(f"Found vocab files: {self.files}") - def _select_file(self, vocabtype: Optional[str]) -> Path: + def _select_file(self, vocabtype: str | None) -> Path: if vocabtype in ["spm", "bpe"]: for file_key in self.files.keys(): - if self.files[file_key]: - return self.files[file_key] + if (file := self.files[file_key]) is not None: + return file raise FileNotFoundError(f"{vocabtype} vocab not found.") - elif vocabtype == "hfft": + if vocabtype == "hfft": # For Hugging Face Fast Tokenizer, return the directory path instead of a specific file return self.path - else: - raise ValueError(f"Unsupported vocabulary type {vocabtype}") + raise ValueError(f"Unsupported vocabulary type {vocabtype}") - def _create_special_vocab( - self, - vocab: Vocab, - vocabtype: str, - model_parent_path: Path, - ) -> gguf.SpecialVocab: + def _create_special_vocab(self, vocab: Vocab, vocabtype: str, model_parent_path: Path) -> gguf.SpecialVocab: load_merges = vocabtype == "bpe" n_vocab = vocab.vocab_size if hasattr(vocab, "vocab_size") else None return gguf.SpecialVocab( @@ -1407,13 +1311,12 @@ class VocabFactory: n_vocab=n_vocab, ) - def load_vocab( - self, vocabtype: str, model_parent_path: Path - ) -> Tuple[Vocab, gguf.SpecialVocab]: + def load_vocab(self, vocabtype: str, model_parent_path: Path) -> tuple[Vocab, gguf.SpecialVocab]: path = self._select_file(vocabtype) print(f"Loading vocab file '{path}', type '{vocabtype}'") added_tokens_path = path.parent / "added_tokens.json" + vocab: Vocab if vocabtype == "bpe": vocab = BpeVocab( path, added_tokens_path if added_tokens_path.exists() else None @@ -1428,6 +1331,7 @@ class VocabFactory: ) else: raise ValueError(f"Unsupported vocabulary type {vocabtype}") + # FIXME: Respect --vocab-dir? special_vocab = self._create_special_vocab( vocab, vocabtype, @@ -1436,18 +1340,17 @@ class VocabFactory: return vocab, special_vocab -def default_output_file(model_paths: list[Path], file_type: GGMLFileType) -> Path: +def default_outfile(model_paths: list[Path], file_type: GGMLFileType) -> Path: namestr = { - GGMLFileType.AllF32: "f32", + GGMLFileType.AllF32: "f32", GGMLFileType.MostlyF16: "f16", - GGMLFileType.MostlyQ8_0: "q8_0", + GGMLFileType.MostlyQ8_0:"q8_0", }[file_type] ret = model_paths[0].parent / f"ggml-model-{namestr}.gguf" if ret in model_paths: sys.stderr.write( f"Error: Default output path ({ret}) would overwrite the input. " - "Please explicitly specify a path using --outfile.\n" - ) + "Please explicitly specify a path using --outfile.\n") sys.exit(1) return ret @@ -1457,111 +1360,34 @@ def do_dump_model(model_plus: ModelPlus) -> None: print(f"model_plus.format = {model_plus.format!r}") print(f"model_plus.vocab = {model_plus.vocab!r}") for name, lazy_tensor in model_plus.model.items(): - print( - f"{name}: shape={lazy_tensor.shape} type={lazy_tensor.data_type}; {lazy_tensor.description}" - ) + print(f"{name}: shape={lazy_tensor.shape} type={lazy_tensor.data_type}; {lazy_tensor.description}") -def get_argument_parser() -> ArgumentParser: +def main(args_in: list[str] | None = None) -> None: output_choices = ["f32", "f16"] if np.uint32(1) == np.uint32(1).newbyteorder("<"): # We currently only support Q8_0 output on little endian systems. output_choices.append("q8_0") + vocab_types = ["spm", "bpe", "hfft"] + parser = argparse.ArgumentParser(description="Convert a LLaMa model to a GGML compatible file") + parser.add_argument("--awq-path", type=Path, help="Path to scale awq cache file", default=None) + parser.add_argument("--dump", action="store_true", help="don't convert, just show what's in the model") + parser.add_argument("--dump-single", action="store_true", help="don't convert, just show what's in a single model file") + parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab") + parser.add_argument("--outtype", choices=output_choices, help="output format - note: q8_0 may be very slow (default: f16 or f32 based on input)") + parser.add_argument("--vocab-dir", type=Path, help="directory containing tokenizer.model, if separate from model file") + parser.add_argument("--vocab-type", choices=vocab_types, help="The vocabulary format used to define the tokenizer model (default: spm)", default="spm") + parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input") + parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)") + parser.add_argument("--ctx", type=int, help="model training context (default: based on input)") + parser.add_argument("--concurrency", type=int, help=f"concurrency used for conversion (default: {DEFAULT_CONCURRENCY})", default=DEFAULT_CONCURRENCY) + parser.add_argument("--big-endian", action="store_true", help="model is executed on big endian machine") + parser.add_argument("--pad-vocab", action="store_true", help="add pad tokens when model vocab expects more than tokenizer metadata provides") - parser = argparse.ArgumentParser( - description="Convert a LLaMa model to a GGML compatible file" - ) - - parser.add_argument( - "model", - type=Path, - help="Directory containing the model file or the model file itself (*.pth, *.pt, *.bin)", - ) - - parser.add_argument( - "--awq-path", - type=Path, - help="Path to the Activation-aware Weight Quantization cache file", - default=None, - ) - - parser.add_argument( - "--dump", - action="store_true", - help="Display the model content without converting it", - ) - - parser.add_argument( - "--dump-single", - action="store_true", - help="Display the content of a single model file without conversion", - ) - - parser.add_argument( - "--vocab-only", - action="store_true", - help="Extract and output only the vocabulary", - ) - - parser.add_argument( - "--outtype", - choices=output_choices, - help="Output format - note: q8_0 may be very slow (default: f16 or f32 based on input)", - ) - - parser.add_argument( - "--vocab-dir", - type=Path, - help="Directory containing the tokenizer.model, if separate from the model file", - ) - - parser.add_argument( - "--vocab-type", - choices=["spm", "bpe", "hfft"], # hfft: Hugging Face Fast Tokenizer - default="spm", - help="The vocabulary format used to define the tokenizer model (default: spm)", - ) - - parser.add_argument( - "--pad-vocab", - action="store_true", - help="Add padding tokens when the model's vocabulary size exceeds the tokenizer metadata", - ) - - parser.add_argument( - "--outfile", - type=Path, - help="Specify the path for the output file (default is based on input)", - ) - - parser.add_argument( - "--ctx", type=int, help="Model training context (default is based on input)" - ) - - parser.add_argument( - "--concurrency", - type=int, - help=f"Concurrency used for conversion (default: {DEFAULT_CONCURRENCY})", - default=DEFAULT_CONCURRENCY, - ) - - parser.add_argument( - "--big-endian", - action="store_true", - help="Indicate that the model is executed on a big-endian machine", - ) - - return parser - - -def main(argv: Optional[list[str]] = None) -> None: - parser = get_argument_parser() - args = parser.parse_args(argv) - + args = parser.parse_args(args_in) if args.awq_path: - sys.path.insert(1, str(Path(__file__).resolve().parent / "awq-py")) - from awq.apply_awq import add_scale_weights - + sys.path.insert(1, str(Path(__file__).parent / 'awq-py')) + from awq.apply_awq import add_scale_weights # type: ignore[import-not-found] tmp_model_path = args.model / "weighted_model" if tmp_model_path.is_dir(): print(f"{tmp_model_path} exists as a weighted model.") @@ -1580,14 +1406,11 @@ def main(argv: Optional[list[str]] = None) -> None: if not args.vocab_only: model_plus = load_some_model(args.model) else: - model_plus = ModelPlus( - model={}, paths=[args.model / "dummy"], format="none", vocab=None - ) + model_plus = ModelPlus(model = {}, paths = [args.model / 'dummy'], format = 'none', vocab = None) if args.dump: do_dump_model(model_plus) return - endianess = gguf.GGUFEndian.LITTLE if args.big_endian: endianess = gguf.GGUFEndian.BIG @@ -1595,12 +1418,10 @@ def main(argv: Optional[list[str]] = None) -> None: params = Params.load(model_plus) if params.n_ctx == -1: if args.ctx is None: - raise Exception( - "The model doesn't have a context size, and you didn't specify one with --ctx\n" - "Please specify one with --ctx:\n" - " - LLaMA v1: --ctx 2048\n" - " - LLaMA v2: --ctx 4096\n" - ) + raise Exception("The model doesn't have a context size, and you didn't specify one with --ctx\n" + "Please specify one with --ctx:\n" + " - LLaMA v1: --ctx 2048\n" + " - LLaMA v2: --ctx 4096\n") params.n_ctx = args.ctx if args.outtype: @@ -1621,42 +1442,30 @@ def main(argv: Optional[list[str]] = None) -> None: if not args.outfile: raise ValueError("need --outfile if using --vocab-only") outfile = args.outfile - OutputFile.write_vocab_only( - outfile, - params, - vocab, - special_vocab, - endianess=endianess, - pad_vocab=args.pad_vocab, - ) + OutputFile.write_vocab_only(outfile, params, vocab, special_vocab, + endianess=endianess, pad_vocab=args.pad_vocab) print(f"Wrote {outfile}") return if model_plus.vocab is not None and args.vocab_dir is None: vocab = model_plus.vocab - model = model_plus.model - model = convert_model_names(model, params) - ftype = pick_output_type(model, args.outtype) - model = convert_to_output_type(model, ftype) - outfile = args.outfile or default_output_file(model_plus.paths, ftype) + print(f"Vocab info: {vocab}") + print(f"Special vocab info: {special_vocab}") + + model = model_plus.model + model = convert_model_names(model, params) + ftype = pick_output_type(model, args.outtype) + model = convert_to_output_type(model, ftype) + outfile = args.outfile or default_outfile(model_plus.paths, ftype) params.ftype = ftype print(f"Writing {outfile}, format {ftype}") - OutputFile.write_all( - outfile, - ftype, - params, - model, - vocab, - special_vocab, - concurrency=args.concurrency, - endianess=endianess, - pad_vocab=args.pad_vocab, - ) + OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab, + concurrency=args.concurrency, endianess=endianess, pad_vocab=args.pad_vocab) print(f"Wrote {outfile}") -if __name__ == "__main__": - main(sys.argv[1:]) # Exclude the first element (script name) from sys.argv +if __name__ == '__main__': + main() diff --git a/mypy.ini b/mypy.ini index 7215a05dd..e51910ca7 100644 --- a/mypy.ini +++ b/mypy.ini @@ -4,3 +4,4 @@ allow_untyped_calls = true allow_untyped_defs = true allow_incomplete_defs = true disable_error_code = import-untyped +warn_return_any = false