23 | 23 | )
24 | 24 | from transformers import (
25 | 25 | AutoConfig,
26 | | - AutoModelForCausalLM,
27 | 26 | AutoTokenizer,
28 | 27 | BitsAndBytesConfig,
29 | 28 | )
@@ -64,11 +63,11 @@ def load_pretrained_model(model_path: str,
64 | 63 | kwargs['torch_dtype'] = torch.float16 # type: ignore
65 | 64 |
66 | 65 | if 'lita' not in model_name.lower():
67 | | - warnings.warn("this function is for loading LITA models")
| 66 | + warnings.warn("this function is for loading LITA models", stacklevel=2)
68 | 67 | if 'lora' in model_name.lower():
69 | | - warnings.warn("lora is currently not supported for LITA")
| 68 | + warnings.warn("lora is currently not supported for LITA", stacklevel=2)
70 | 69 | if 'mpt' in model_name.lower():
71 | | - warnings.warn("mpt is currently not supported for LITA")
| 70 | + warnings.warn("mpt is currently not supported for LITA", stacklevel=2)
72 | 71 |
73 | 72 | if model_base is not None:
74 | 73 | print('Loading LITA from base model...')
@@ -107,26 +106,26 @@ def load_pretrained_model(model_path: str,
107 | 106 | assert num_new_tokens == 0, "time tokens should already be in the tokenizer for full finetune model"
108 | 107 |
109 | 108 | if num_new_tokens > 0:
110 | | - warnings.warn("looking for weights in mm_projector.bin")
| 109 | + warnings.warn("looking for weights in mm_projector.bin", stacklevel=2)
111 | 110 | assert num_new_tokens == num_time_tokens
112 | 111 | model.resize_token_embeddings(len(tokenizer))
113 | 112 | input_embeddings = model.get_input_embeddings().weight.data
114 | 113 | output_embeddings = model.get_output_embeddings().weight.data
115 | 114 | weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
116 | 115 | assert 'model.embed_tokens.weight' in weights and 'lm_head.weight' in weights
117 | | -
| 116 | +
118 | 117 | dtype = input_embeddings.dtype
119 | 118 | device = input_embeddings.device
120 | | -
| 119 | +
121 | 120 | tokenizer_time_token_ids = tokenizer.convert_tokens_to_ids(time_tokens)
122 | 121 | time_token_ids = getattr(model.config, 'time_token_ids', tokenizer_time_token_ids)
123 | 122 | input_embeddings[tokenizer_time_token_ids] = weights['model.embed_tokens.weight'][time_token_ids].to(dtype).to(device)
124 | 123 | output_embeddings[tokenizer_time_token_ids] = weights['lm_head.weight'][time_token_ids].to(dtype).to(device)
125 | | -
| 124 | +
126 | 125 | if hasattr(model.config, "max_sequence_length"):
127 | 126 | context_len = model.config.max_sequence_length
128 | 127 | else:
129 | | - context_len = 2048
| 128 | + context_len = 2048
130 | 129 | return tokenizer, model, image_processor, context_len
131 | 130 |
132 | 131 |
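The substantive change in these hunks is passing stacklevel=2 to warnings.warn, so each warning is attributed to the code that called load_pretrained_model rather than to the line inside it. A minimal, standalone sketch of that behaviour (the function and message below are illustrative stand-ins, not the LITA code itself):

import warnings

def load_model(model_name):
    # With the default stacklevel=1 the warning is reported against this warn()
    # call; stacklevel=2 walks one frame up the stack, so it is reported against
    # the caller of load_model instead.
    if 'lita' not in model_name.lower():
        warnings.warn("this function is for loading LITA models", stacklevel=2)

def run():
    load_model("llava-v1.5-7b")  # the UserWarning points at this line

run()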
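For orientation, a hedged sketch of how the function in this diff might be called. The checkpoint path is made up, and every keyword name other than model_path is an assumption inferred from the variables visible above (model_base, model_name), not confirmed by the truncated signature in the hunk headers:

# Hypothetical usage sketch; argument names beyond model_path are assumptions.
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path="checkpoints/lita-7b",   # assumed local checkpoint directory
    model_base=None,                     # full finetune, so no separate base model
    model_name="lita-7b",                # must contain "lita" or a warning is raised
)
# context_len falls back to 2048 unless model.config.max_sequence_length is set.
print(context_len)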