Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 78 additions & 8 deletions minicheck/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def fact_check(self, doc, claim):

class LLMCheck:

def __init__(self, model_id, tensor_parallel_size=1, max_tokens=1, cache_dir=None, enable_prefix_caching=False, max_model_len=None):
def __init__(self, model_id, peft_path=None, max_lora_rank=16, operating_mode="bespoke", think_end_token="</think>", extra_chat_template_kwargs=None, tensor_parallel_size=1, max_tokens=1, cache_dir=None, enable_prefix_caching=False, max_model_len=None):
from vllm import LLM, SamplingParams

import logging
Expand All @@ -288,11 +288,23 @@ def __init__(self, model_id, tensor_parallel_size=1, max_tokens=1, cache_dir=Non
if model_id == 'Bespoke-MiniCheck-7B':
self.model_id = 'bespokelabs/Bespoke-MiniCheck-7B'
self.operating_mode="bespoke"

self.extra_chat_template_kwargs = {}
elif model_id == 'Granite-Guardian-3.3-8B':
self.model_id = 'ibm-granite/granite-guardian-3.3-8b'
self.operating_mode="gg_hybrid"

self.extra_chat_template_kwargs = {
'guardian_config': {"criteria_id": "groundedness"},
'think': True
}
else:
raise ValueError("model_id must be 'Bespoke-MiniCheck-7B'")
self.model_id = model_id
self.operating_mode=operating_mode

self.extra_chat_template_kwargs = extra_chat_template_kwargs if extra_chat_template_kwargs is not None else {}

self.peft_path = peft_path

self.tensor_parallel_size = tensor_parallel_size
self.max_tokens = max_tokens
Expand Down Expand Up @@ -329,7 +341,9 @@ def __init__(self, model_id, tensor_parallel_size=1, max_tokens=1, cache_dir=Non
tensor_parallel_size=self.tensor_parallel_size,
seed=2024,
max_model_len=self.max_model_len, # need to be adjusted based on the GPU memory available
enable_prefix_caching=self.enable_prefix_caching
enable_prefix_caching=self.enable_prefix_caching,
max_lora_rank=max_lora_rank,
enable_lora=True if self.peft_path is not None else False
)

self.tokenizer = self.llm.get_tokenizer()
Expand All @@ -338,9 +352,13 @@ def __init__(self, model_id, tensor_parallel_size=1, max_tokens=1, cache_dir=Non
self.tokenizer.eos_token_id,
]
converted_token = self.tokenizer.convert_tokens_to_ids("<|eot_id|>")

if converted_token is not None:
terminators.append(converted_token)

if operating_mode == "thinking":
self.thinking_end_token=self.tokenizer.convert_tokens_to_ids(think_end_token)

self.sampling_params = SamplingParams(
temperature=0,
max_tokens=self.max_tokens,
Expand Down Expand Up @@ -368,12 +386,19 @@ def apply_chat_template(self, doc, claim):
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": user_prompt},
]
text = self.tokenizer.apply_chat_template(message, add_generation_prompt=True, tokenize=False)
text = self.tokenizer.apply_chat_template(message, add_generation_prompt=True, tokenize=False, **self.extra_chat_template_kwargs)
elif self.operating_mode=="gg_hybrid":
documents = [{'doc_id':'0', 'text': doc}]
messages = [{"role": "assistant", "content": claim}]
guardian_config = {"criteria_id": "groundedness"}
text = self.tokenizer.apply_chat_template(messages, guardian_config = guardian_config, documents=documents, think=True, tokenize=False, add_generation_prompt=True)
text = self.tokenizer.apply_chat_template(messages, documents=documents, add_generation_prompt=True, tokenize=False, **self.extra_chat_template_kwargs)
elif self.operating_mode=="thinking":
user_prompt = self.user_prompt.replace("[DOCUMENT]", doc).replace("[CLAIM]", claim)
message = [
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": user_prompt},
]
text = self.tokenizer.apply_chat_template(message, add_generation_prompt=True, tokenize=False, **self.extra_chat_template_kwargs)

return text


Expand All @@ -398,6 +423,38 @@ def get_support_prob_hybrid_gg(self, response, marker="score"):
print("Error:", e)
support_prob = random.random()
return support_prob

def get_support_prob_thinking(self, response):
"""probs from vllm inference"""
import math
support_prob = 0
start_response_index = -1

completion = response.outputs[0]

try:
if self.thinking_end_token in completion.token_ids:
max_token_index = len(completion.token_ids) - 1
thinking_token_index = completion.token_ids.index(self.thinking_end_token) + 1

decoded_token = next(iter(completion.logprobs[thinking_token_index].values())).decoded_token

while("\n" in decoded_token and max_token_index):
thinking_token_index += 1
decoded_token = next(iter(completion.logprobs[thinking_token_index].values())).decoded_token

if thinking_token_index <= max_token_index:
start_response_index = thinking_token_index

for token_prob in completion.logprobs[start_response_index].values():
decoded_token = token_prob.decoded_token
if decoded_token.lower() == 'yes':
support_prob += math.exp(token_prob.logprob)
except Exception as e:
print("Error:", e)
support_prob = random.random()

return support_prob


def get_all_chunks_per_doc(self, doc, claim):
Expand Down Expand Up @@ -464,11 +521,24 @@ def score(self, docs: List[str], claims: List[str], chunk_size=None) -> List[flo
all_prompts.extend(prompts)
doc_claim_indices.extend([index] * len(prompts))

responses = self.llm.generate(all_prompts, self.sampling_params)
if self.peft_path is not None:
from vllm.lora.request import LoRARequest

responses = self.llm.generate(
all_prompts,
self.sampling_params,
lora_request=LoRARequest("lora_adapter", 1, self.peft_path) if self.peft_path else None)
else:
responses = self.llm.generate(
all_prompts,
self.sampling_params)

if self.operating_mode=="bespoke":
probs_per_chunk_sentence = [self.get_support_prob(responses[idx]) for idx in range(len(responses))]
elif self.operating_mode=="gg_hybrid":
probs_per_chunk_sentence = [self.get_support_prob_hybrid_gg(responses[idx]) for idx in range(len(responses))]
elif self.operating_mode=="thinking":
probs_per_chunk_sentence = [self.get_support_prob_thinking(responses[idx]) for idx in range(len(responses))]

result_dict = {}
for index, prob_per_chunk_sentence in zip(doc_claim_indices, probs_per_chunk_sentence):
Expand Down Expand Up @@ -504,4 +574,4 @@ def score(self, docs: List[str], claims: List[str], chunk_size=None) -> List[flo
return pred_label, max_support_prob, used_chunk, support_prob_per_chunk

def split_into_sentences(self, text: str) -> List[str]:
return nltk.sent_tokenize(text)
return nltk.sent_tokenize(text)
64 changes: 60 additions & 4 deletions minicheck/minicheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


class MiniCheck:
def __init__(self, model_name='Bespoke-MiniCheck-7B', max_model_len=None, batch_size=16, cache_dir=None, tensor_parallel_size=1, max_tokens=1, enable_prefix_caching=False) -> None:
def __init__(self, model_name='Bespoke-MiniCheck-7B', peft_path=None, max_lora_rank=16, operating_mode="bespoke", think_end_token=None, extra_chat_template_kwargs=None, max_model_len=None, batch_size=16, cache_dir=None, tensor_parallel_size=1, max_tokens=1, enable_prefix_caching=False, bypass_model_check=False) -> None:

'''
Parameters:
Expand All @@ -19,6 +19,33 @@ def __init__(self, model_name='Bespoke-MiniCheck-7B', max_model_len=None, batch_
- 'Bespoke-MiniCheck-7B'
- 'Granite-Guardian-3.3-8B'
Note: 'Bespoke-MiniCheck-7B' is the most performant fact-checking model in the MiniCheck series.

peft_path : str, optional (default=None)
Path to the LLM PEFT adapter
- 'Bespoke-MiniCheck-7B'
peft_path: None
- 'Granite-Guardian-3.3-8B'
peft_path: None

max_lora_rank : int, optional (default=16)
Maximum LoRA Adapter Rank to load

operating_mode : str, optional (default='bespoke')
LLM model support probability operating mode
Preset models use their corresponding operating mode, i.e:
- 'Bespoke-MiniCheck-7B'
Operating Mode: 'bespoke'
- 'Granite-Guardian-3.3-8B'
Operating Mode: 'gg_hybrid'
Extra operating mode:
- 'thinking' uses the first logprobs after the thinking delimiter as support probability

think_end_token : str, optional (default=None)
Token used to represent the end of the thinking traces of LLM models

extra_chat_template_kwargs : dict, optional (default=None)
Extra kwargs to forward to the chat template
Preset models use their corresponding chat template kwargs

max_model_len : int or None, optional (default=None)
The maximum input length for the model. If None, we use the following default values.
Expand Down Expand Up @@ -57,6 +84,9 @@ def __init__(self, model_name='Bespoke-MiniCheck-7B', max_model_len=None, batch_
Whether to enable prefix caching for 'Bespoke-MiniCheck-7B'. This can improve performance
when using the same document chunk to fact-check different claims.

bypass_model_check: bool, optional (default=False)
Allows to bypass the model check to run the benchmark on different models with various configuration

Note:
(1) MiniCheck-Flan-T5-Large (770M) is the best fack-checking model with size < 1B and reaches GPT-4 performance.
(2) Bespoke-MiniCheck-7B is the most performant fact-checking model in the MiniCheck series AND
Expand All @@ -72,18 +102,24 @@ def __init__(self, model_name='Bespoke-MiniCheck-7B', max_model_len=None, batch_
future grounded fact-checking with much higher throughput and much lower latency.
'''

assert model_name in ['roberta-large', 'deberta-v3-large', 'flan-t5-large', 'Bespoke-MiniCheck-7B', 'Granite-Guardian-3.3-8B'], \
"model_name must be one of ['roberta-large', 'deberta-v3-large', 'flan-t5-large', 'Bespoke-MiniCheck-7B', 'Granite-Guardian-3.3-8B']"

if not bypass_model_check:
assert model_name in ['roberta-large', 'deberta-v3-large', 'flan-t5-large', 'Bespoke-MiniCheck-7B', 'Granite-Guardian-3.3-8B'], \
"model_name must be one of ['roberta-large', 'deberta-v3-large', 'flan-t5-large', 'Bespoke-MiniCheck-7B', 'Granite-Guardian-3.3-8B']"

if model_name in ['roberta-large', 'deberta-v3-large', 'flan-t5-large']:
if operating_mode != 'operating_mode' or extra_chat_template_kwargs is not None or peft_path is not None or think_end_token is not None:
print(f"Forcing default preset configuration for model {model_name}")

self.model = Inferencer(
model_name=model_name,
batch_size=batch_size,
max_model_len=max_model_len,
cache_dir=cache_dir
)
elif model_name == 'Bespoke-MiniCheck-7B':
if operating_mode != 'bespoke' or extra_chat_template_kwargs is not None or peft_path is not None or think_end_token is not None:
print("Forcing default preset configuration for model Bespoke-MiniCheck-7B")

self.model = LLMCheck(
model_id=model_name,
tensor_parallel_size=tensor_parallel_size,
Expand All @@ -93,6 +129,9 @@ def __init__(self, model_name='Bespoke-MiniCheck-7B', max_model_len=None, batch_
max_model_len=max_model_len
)
elif model_name == 'Granite-Guardian-3.3-8B':
if operating_mode != 'gg_hybrid' or extra_chat_template_kwargs is not None or peft_path is not None or think_end_token is not None:
print("Forcing default preset configuration for model Granite Guardian 3.3")

if not max_tokens or max_tokens<2048:
print("For Granite Guardian 3.3 - fixing the max_tokens to be 2048")
max_tokens=2048
Expand All @@ -105,6 +144,23 @@ def __init__(self, model_name='Bespoke-MiniCheck-7B', max_model_len=None, batch_
enable_prefix_caching=enable_prefix_caching,
max_model_len=max_model_len
)
else:
if operating_mode == "thinking":
assert think_end_token is not None, "'thinking' operating mode requires to specify a 'think_end_token'"

self.model = LLMCheck(
model_id=model_name,
peft_path=peft_path,
max_lora_rank=max_lora_rank,
operating_mode=operating_mode,
think_end_token=think_end_token,
extra_chat_template_kwargs=extra_chat_template_kwargs,
tensor_parallel_size=tensor_parallel_size,
max_tokens=max_tokens,
cache_dir=cache_dir,
enable_prefix_caching=enable_prefix_caching,
max_model_len=max_model_len
)


def score(self, docs: List[str], claims: List[str], chunk_size=None) -> List[float]:
Expand Down