diff --git a/minicheck/inference.py b/minicheck/inference.py index 074d1b7..8c3c046 100644 --- a/minicheck/inference.py +++ b/minicheck/inference.py @@ -271,7 +271,7 @@ def fact_check(self, doc, claim): class LLMCheck: - def __init__(self, model_id, tensor_parallel_size=1, max_tokens=1, cache_dir=None, enable_prefix_caching=False, max_model_len=None): + def __init__(self, model_id, peft_path=None, max_lora_rank=16, operating_mode="bespoke", think_end_token="", extra_chat_template_kwargs=None, tensor_parallel_size=1, max_tokens=1, cache_dir=None, enable_prefix_caching=False, max_model_len=None): from vllm import LLM, SamplingParams import logging @@ -288,11 +288,23 @@ def __init__(self, model_id, tensor_parallel_size=1, max_tokens=1, cache_dir=Non if model_id == 'Bespoke-MiniCheck-7B': self.model_id = 'bespokelabs/Bespoke-MiniCheck-7B' self.operating_mode="bespoke" + + self.extra_chat_template_kwargs = {} elif model_id == 'Granite-Guardian-3.3-8B': self.model_id = 'ibm-granite/granite-guardian-3.3-8b' self.operating_mode="gg_hybrid" + + self.extra_chat_template_kwargs = { + 'guardian_config': {"criteria_id": "groundedness"}, + 'think': True + } else: - raise ValueError("model_id must be 'Bespoke-MiniCheck-7B'") + self.model_id = model_id + self.operating_mode=operating_mode + + self.extra_chat_template_kwargs = extra_chat_template_kwargs if extra_chat_template_kwargs is not None else {} + + self.peft_path = peft_path self.tensor_parallel_size = tensor_parallel_size self.max_tokens = max_tokens @@ -329,7 +341,9 @@ def __init__(self, model_id, tensor_parallel_size=1, max_tokens=1, cache_dir=Non tensor_parallel_size=self.tensor_parallel_size, seed=2024, max_model_len=self.max_model_len, # need to be adjusted based on the GPU memory available - enable_prefix_caching=self.enable_prefix_caching + enable_prefix_caching=self.enable_prefix_caching, + max_lora_rank=max_lora_rank, + enable_lora=True if self.peft_path is not None else False ) self.tokenizer = self.llm.get_tokenizer() @@ -338,9 +352,13 @@ def __init__(self, model_id, tensor_parallel_size=1, max_tokens=1, cache_dir=Non self.tokenizer.eos_token_id, ] converted_token = self.tokenizer.convert_tokens_to_ids("<|eot_id|>") + if converted_token is not None: terminators.append(converted_token) + if operating_mode == "thinking": + self.thinking_end_token=self.tokenizer.convert_tokens_to_ids(think_end_token) + self.sampling_params = SamplingParams( temperature=0, max_tokens=self.max_tokens, @@ -368,12 +386,19 @@ def apply_chat_template(self, doc, claim): {"role": "system", "content": self.system_prompt}, {"role": "user", "content": user_prompt}, ] - text = self.tokenizer.apply_chat_template(message, add_generation_prompt=True, tokenize=False) + text = self.tokenizer.apply_chat_template(message, add_generation_prompt=True, tokenize=False, **self.extra_chat_template_kwargs) elif self.operating_mode=="gg_hybrid": documents = [{'doc_id':'0', 'text': doc}] messages = [{"role": "assistant", "content": claim}] - guardian_config = {"criteria_id": "groundedness"} - text = self.tokenizer.apply_chat_template(messages, guardian_config = guardian_config, documents=documents, think=True, tokenize=False, add_generation_prompt=True) + text = self.tokenizer.apply_chat_template(messages, documents=documents, add_generation_prompt=True, tokenize=False, **self.extra_chat_template_kwargs) + elif self.operating_mode=="thinking": + user_prompt = self.user_prompt.replace("[DOCUMENT]", doc).replace("[CLAIM]", claim) + message = [ + {"role": "system", "content": self.system_prompt}, + {"role": "user", "content": user_prompt}, + ] + text = self.tokenizer.apply_chat_template(message, add_generation_prompt=True, tokenize=False, **self.extra_chat_template_kwargs) + return text @@ -398,6 +423,38 @@ def get_support_prob_hybrid_gg(self, response, marker="score"): print("Error:", e) support_prob = random.random() return support_prob + + def get_support_prob_thinking(self, response): + """probs from vllm inference""" + import math + support_prob = 0 + start_response_index = -1 + + completion = response.outputs[0] + + try: + if self.thinking_end_token in completion.token_ids: + max_token_index = len(completion.token_ids) - 1 + thinking_token_index = completion.token_ids.index(self.thinking_end_token) + 1 + + decoded_token = next(iter(completion.logprobs[thinking_token_index].values())).decoded_token + + while("\n" in decoded_token and max_token_index): + thinking_token_index += 1 + decoded_token = next(iter(completion.logprobs[thinking_token_index].values())).decoded_token + + if thinking_token_index <= max_token_index: + start_response_index = thinking_token_index + + for token_prob in completion.logprobs[start_response_index].values(): + decoded_token = token_prob.decoded_token + if decoded_token.lower() == 'yes': + support_prob += math.exp(token_prob.logprob) + except Exception as e: + print("Error:", e) + support_prob = random.random() + + return support_prob def get_all_chunks_per_doc(self, doc, claim): @@ -464,11 +521,24 @@ def score(self, docs: List[str], claims: List[str], chunk_size=None) -> List[flo all_prompts.extend(prompts) doc_claim_indices.extend([index] * len(prompts)) - responses = self.llm.generate(all_prompts, self.sampling_params) + if self.peft_path is not None: + from vllm.lora.request import LoRARequest + + responses = self.llm.generate( + all_prompts, + self.sampling_params, + lora_request=LoRARequest("lora_adapter", 1, self.peft_path) if self.peft_path else None) + else: + responses = self.llm.generate( + all_prompts, + self.sampling_params) + if self.operating_mode=="bespoke": probs_per_chunk_sentence = [self.get_support_prob(responses[idx]) for idx in range(len(responses))] elif self.operating_mode=="gg_hybrid": probs_per_chunk_sentence = [self.get_support_prob_hybrid_gg(responses[idx]) for idx in range(len(responses))] + elif self.operating_mode=="thinking": + probs_per_chunk_sentence = [self.get_support_prob_thinking(responses[idx]) for idx in range(len(responses))] result_dict = {} for index, prob_per_chunk_sentence in zip(doc_claim_indices, probs_per_chunk_sentence): @@ -504,4 +574,4 @@ def score(self, docs: List[str], claims: List[str], chunk_size=None) -> List[flo return pred_label, max_support_prob, used_chunk, support_prob_per_chunk def split_into_sentences(self, text: str) -> List[str]: - return nltk.sent_tokenize(text) \ No newline at end of file + return nltk.sent_tokenize(text) diff --git a/minicheck/minicheck.py b/minicheck/minicheck.py index a163ec6..9140123 100644 --- a/minicheck/minicheck.py +++ b/minicheck/minicheck.py @@ -6,7 +6,7 @@ class MiniCheck: - def __init__(self, model_name='Bespoke-MiniCheck-7B', max_model_len=None, batch_size=16, cache_dir=None, tensor_parallel_size=1, max_tokens=1, enable_prefix_caching=False) -> None: + def __init__(self, model_name='Bespoke-MiniCheck-7B', peft_path=None, max_lora_rank=16, operating_mode="bespoke", think_end_token=None, extra_chat_template_kwargs=None, max_model_len=None, batch_size=16, cache_dir=None, tensor_parallel_size=1, max_tokens=1, enable_prefix_caching=False, bypass_model_check=False) -> None: ''' Parameters: @@ -19,6 +19,33 @@ def __init__(self, model_name='Bespoke-MiniCheck-7B', max_model_len=None, batch_ - 'Bespoke-MiniCheck-7B' - 'Granite-Guardian-3.3-8B' Note: 'Bespoke-MiniCheck-7B' is the most performant fact-checking model in the MiniCheck series. + + peft_path : str, optional (default=None) + Path to the LLM PEFT adapter + - 'Bespoke-MiniCheck-7B' + peft_path: None + - 'Granite-Guardian-3.3-8B' + peft_path: None + + max_lora_rank : int, optional (default=16) + Maximum LoRA Adapter Rank to load + + operating_mode : str, optional (default='bespoke') + LLM model support probability operating mode + Preset models use their corresponding operating mode, i.e: + - 'Bespoke-MiniCheck-7B' + Operating Mode: 'bespoke' + - 'Granite-Guardian-3.3-8B' + Operating Mode: 'gg_hybrid' + Extra operating mode: + - 'thinking' uses the first logprobs after the thinking delimiter as support probability + + think_end_token : str, optional (default=None) + Token used to represent the end of the thinking traces of LLM models + + extra_chat_template_kwargs : dict, optional (default=None) + Extra kwargs to forward to the chat template + Preset models use their corresponding chat template kwargs max_model_len : int or None, optional (default=None) The maximum input length for the model. If None, we use the following default values. @@ -57,6 +84,9 @@ def __init__(self, model_name='Bespoke-MiniCheck-7B', max_model_len=None, batch_ Whether to enable prefix caching for 'Bespoke-MiniCheck-7B'. This can improve performance when using the same document chunk to fact-check different claims. + bypass_model_check: bool, optional (default=False) + Allows to bypass the model check to run the benchmark on different models with various configuration + Note: (1) MiniCheck-Flan-T5-Large (770M) is the best fack-checking model with size < 1B and reaches GPT-4 performance. (2) Bespoke-MiniCheck-7B is the most performant fact-checking model in the MiniCheck series AND @@ -72,11 +102,14 @@ def __init__(self, model_name='Bespoke-MiniCheck-7B', max_model_len=None, batch_ future grounded fact-checking with much higher throughput and much lower latency. ''' - assert model_name in ['roberta-large', 'deberta-v3-large', 'flan-t5-large', 'Bespoke-MiniCheck-7B', 'Granite-Guardian-3.3-8B'], \ - "model_name must be one of ['roberta-large', 'deberta-v3-large', 'flan-t5-large', 'Bespoke-MiniCheck-7B', 'Granite-Guardian-3.3-8B']" - + if not bypass_model_check: + assert model_name in ['roberta-large', 'deberta-v3-large', 'flan-t5-large', 'Bespoke-MiniCheck-7B', 'Granite-Guardian-3.3-8B'], \ + "model_name must be one of ['roberta-large', 'deberta-v3-large', 'flan-t5-large', 'Bespoke-MiniCheck-7B', 'Granite-Guardian-3.3-8B']" if model_name in ['roberta-large', 'deberta-v3-large', 'flan-t5-large']: + if operating_mode != 'operating_mode' or extra_chat_template_kwargs is not None or peft_path is not None or think_end_token is not None: + print(f"Forcing default preset configuration for model {model_name}") + self.model = Inferencer( model_name=model_name, batch_size=batch_size, @@ -84,6 +117,9 @@ def __init__(self, model_name='Bespoke-MiniCheck-7B', max_model_len=None, batch_ cache_dir=cache_dir ) elif model_name == 'Bespoke-MiniCheck-7B': + if operating_mode != 'bespoke' or extra_chat_template_kwargs is not None or peft_path is not None or think_end_token is not None: + print("Forcing default preset configuration for model Bespoke-MiniCheck-7B") + self.model = LLMCheck( model_id=model_name, tensor_parallel_size=tensor_parallel_size, @@ -93,6 +129,9 @@ def __init__(self, model_name='Bespoke-MiniCheck-7B', max_model_len=None, batch_ max_model_len=max_model_len ) elif model_name == 'Granite-Guardian-3.3-8B': + if operating_mode != 'gg_hybrid' or extra_chat_template_kwargs is not None or peft_path is not None or think_end_token is not None: + print("Forcing default preset configuration for model Granite Guardian 3.3") + if not max_tokens or max_tokens<2048: print("For Granite Guardian 3.3 - fixing the max_tokens to be 2048") max_tokens=2048 @@ -105,6 +144,23 @@ def __init__(self, model_name='Bespoke-MiniCheck-7B', max_model_len=None, batch_ enable_prefix_caching=enable_prefix_caching, max_model_len=max_model_len ) + else: + if operating_mode == "thinking": + assert think_end_token is not None, "'thinking' operating mode requires to specify a 'think_end_token'" + + self.model = LLMCheck( + model_id=model_name, + peft_path=peft_path, + max_lora_rank=max_lora_rank, + operating_mode=operating_mode, + think_end_token=think_end_token, + extra_chat_template_kwargs=extra_chat_template_kwargs, + tensor_parallel_size=tensor_parallel_size, + max_tokens=max_tokens, + cache_dir=cache_dir, + enable_prefix_caching=enable_prefix_caching, + max_model_len=max_model_len + ) def score(self, docs: List[str], claims: List[str], chunk_size=None) -> List[float]: