A GPU-accelerated inference engine for Named Entity Recognition (NER) using causal decoder language models with a shared-prefix KV cache.
NER requires scoring many candidate spans against the same document context. DecodNER exploits this by encoding the context (instruction + entity list + document) once as a shared prefix, then scoring all candidate spans in a single batched pass — without re-encoding the prefix for each span.
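A rough token count shows why sharing the prefix pays off. The sketch below is plain arithmetic under simplifying assumptions: every span has the same length, padding is ignored, and only the tokens fed through the model are counted, not attention cost. `tokens_processed` is a hypothetical helper and is not part of DecodNER.

```python
def tokens_processed(prefix_len: int, span_len: int, num_spans: int) -> tuple[int, int]:
    """Return (naive, shared) token counts for scoring `num_spans` spans.

    naive:  each span gets its own forward pass that re-encodes the prefix.
    shared: the prefix is encoded once; only span tokens are processed per span.
    """
    naive = num_spans * (prefix_len + span_len)
    shared = prefix_len + num_spans * span_len
    return naive, shared


naive, shared = tokens_processed(prefix_len=500, span_len=4, num_spans=100)
print(naive, shared)  # 50400 900
```

For a 500-token prefix and 100 four-token spans, that is 50,400 versus 900 tokens, a 56× reduction. The gap widens as documents get longer or the number of candidate spans grows.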
```
┌─────────────────────────────────────────────────────┐
│  PREFIX (encoded once, KV cache shared across all)  │
│  [BOS] <instruction> <entity list> <document text>  │
└──────────────────────┬──────────────────────────────┘
                       │ shared KV cache
          ┌────────────┼────────────┐
          ▼            ▼            ▼
       "Apple"   "Steve Jobs"  "Cupertino"   ← spans (batched)
          │            │            │
          ▼            ▼            ▼
     logit["2"]   logit["1"]   logit["3"]    ← entity label scores
       Company      Person      Location
```
Each span attends to the full prefix through the shared KV cache and causally to its own earlier tokens; spans in the same batch do not attend to each other. The logits at each span's last real (non-padding) token, restricted to the label token IDs, give the entity classification scores.
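To make the scoring rule concrete, here is a plain-Python sketch of the selection described above, with nested lists standing in for tensors. `entity_scores_sketch` and `softmax` are hypothetical helpers written for illustration. This is not the library's `get_entity_scores`, only the same rule: take the logits at each span's last non-padding position and keep only the label token IDs.

```python
import math


def entity_scores_sketch(logits, label_token_ids, attention_mask):
    """Pick label-token logits at the last real (mask == 1) position of each span.

    logits:          [batch][span_len][vocab_size] nested lists of floats
    label_token_ids: vocab index of each label, in label order
    attention_mask:  [batch][span_len], 1 for real tokens and 0 for padding
    returns:         [batch][num_labels]
    """
    scores = []
    for span_logits, mask in zip(logits, attention_mask):
        # Index of the last real token; correct for left or right padding.
        last = max(i for i, m in enumerate(mask) if m == 1)
        scores.append([span_logits[last][t] for t in label_token_ids])
    return scores


def softmax(row):
    # Subtract the max before exponentiating for numerical stability.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


# Toy input: vocab of 5 tokens, the two labels are token ids 3 and 4.
logits = [
    [[0, 0, 0, 1.0, 2.0], [0, 0, 0, 3.0, 0.5], [9, 9, 9, 9.0, 9.0]],  # 2 real tokens + 1 pad
    [[0, 0, 0, 0.2, 0.1], [0, 0, 0, 0.0, 4.0], [0, 0, 0, 1.0, 0.0]],  # 3 real tokens
]
mask = [[1, 1, 0], [1, 1, 1]]

scores = entity_scores_sketch(logits, [3, 4], mask)
print(scores)  # [[3.0, 0.5], [1.0, 0.0]]
probabilities = [softmax(row) for row in scores]
predicted = [row.index(max(row)) for row in scores]  # [0, 0]
```

The padded position in the first span holds large logits (9.0) precisely to show that it is never selected.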
Requires Python ≥ 3.12, a CUDA GPU, and PyTorch ≥ 2.11.
```python
from transformers import AutoTokenizer
from decodner import UniversalDecoder, tokenize_spans, get_label_token_ids, get_entity_scores

# Load any supported causal decoder model from the HuggingFace Hub
model = UniversalDecoder.from_hub("meta-llama/Llama-3.2-1B")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")

# Build the shared prefix: instruction + numbered entity list + document
prompt = (
    "Identify the entity type of the marked span.\n"
    "Entity types: 1. Person 2. Company 3. Location 0. None\n\n"
    "Text: Apple was founded by Steve Jobs in Cupertino.\n"
    "Span: "
)
prefix_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=True).input_ids.cuda()

# Tokenize all candidate spans at once (padding handled automatically)
spans = ["Apple", "Steve Jobs", "Cupertino", "founded"]
span_ids, position_ids, attention_mask = tokenize_spans(
    spans,
    tokenizer,
    prefix_len=prefix_ids.shape[1],
)

# Encode prefix once, score all spans in a single batched pass
logits = model.score_spans(prefix_ids, span_ids, position_ids)
# → [4, max_span_len, vocab_size]

# Extract scores for label tokens at the last real token of each span
label_token_ids = get_label_token_ids(["0", "1", "2", "3"], tokenizer)
scores = get_entity_scores(logits, label_token_ids, attention_mask)
# → [4, 4]  (4 spans × 4 labels)

predicted_label_idx = scores.argmax(dim=-1)
labels = ["None", "Person", "Company", "Location"]  # labels[i] ↔ label token "i"
for span, idx in zip(spans, predicted_label_idx.tolist()):
    print(f"{span:15s} → {labels[idx]}")
# Apple           → Company
# Steve Jobs      → Person
# Cupertino       → Location
# founded         → None
```

`UniversalDecoder.from_hub` loads a model directly from the HuggingFace Hub, downloading the weights in safetensors format.
Supported architectures: Llama, Mistral, Mixtral, Qwen2.
```python
import torch
from decodner import UniversalDecoder

model = UniversalDecoder.from_hub("meta-llama/Llama-3.2-1B", dtype=torch.float16)
```

`score_spans` is an end-to-end convenience method: it encodes the prefix, populates the KV cache, and scores all spans in one call.

```python
logits = model.score_spans(prefix_ids, span_ids, position_ids)
# prefix_ids:   [1, prefix_len]
# span_ids:     [batch, span_len]
# position_ids: [batch, span_len]  (inferred from the prefix length if omitted)
# returns:      [batch, span_len, vocab_size]
```

`DualBufferCache` is the low-level API for reusing a single encoded prefix across multiple `forward_spans` calls, for example when scoring spans against a long document in mini-batches.
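The cache buffers are sized by the constructor arguments in the example that follows, so a quick memory estimate can help choose them. This sketch assumes each layer stores one key and one value entry of `num_kv_heads × head_dim` elements per position, with a single shared prefix buffer and a per-span suffix buffer; that layout is inferred from the description of `DualBufferCache`, not taken from its code. The shape values used (16 layers, 8 KV heads, head dim 64) should match Llama-3.2-1B's `config.json`, but verify them for your model. `kv_cache_bytes` is hypothetical.

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim,
                   max_prefix_len, max_suffix_len, batch_size, bytes_per_elem=2):
    """Estimate (prefix_bytes, suffix_bytes) for a dual-buffer KV cache (hypothetical).

    Assumes K and V each hold num_kv_heads * head_dim elements per position per layer,
    the prefix is stored once (shared), and the suffix is stored once per span.
    """
    per_position = 2 * num_layers * num_kv_heads * head_dim * bytes_per_elem  # K and V
    prefix = per_position * max_prefix_len
    suffix = per_position * batch_size * max_suffix_len
    return prefix, suffix


prefix, suffix = kv_cache_bytes(num_layers=16, num_kv_heads=8, head_dim=64,
                                max_prefix_len=512, max_suffix_len=32, batch_size=64)
print(prefix / 2**20, suffix / 2**20)  # 16.0 64.0  (MiB at fp16)
```

Under these assumptions, copying the prefix into each of the 64 batch rows would instead cost 64 × 16 MiB = 1 GiB, which is the saving a shared prefix buffer captures.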
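```python
import torch
from decodner import DualBufferCache, tokenize_spans, get_entity_scores

max_batch = 64  # spans per forward_spans call
cache = DualBufferCache(
    num_layers=model.config.num_layers,
    num_kv_heads=model.config.num_kv_heads,
    head_dim=model.config.head_dim,
    max_prefix_len=512,   # upper bound on prefix length, in tokens
    max_suffix_len=32,    # upper bound on span length, in tokens
    batch_size=max_batch,
    dtype=torch.float16,
    device="cuda",
)

model.prefill_prefix(prefix_ids, cache)  # encode the prefix once

# Typically many more candidates for a long document
candidate_spans = ["Apple", "Steve Jobs", "Cupertino", "founded"]
for start in range(0, len(candidate_spans), max_batch):
    span_ids, position_ids, mask = tokenize_spans(
        candidate_spans[start:start + max_batch],
        tokenizer,
        prefix_len=prefix_ids.shape[1],
    )
    logits = model.forward_spans(span_ids, position_ids, cache)
    scores = get_entity_scores(logits, label_token_ids, mask)  # [batch, num_labels]
    # process scores ...
```

`tokenize_spans` tokenizes a string or list of strings into padded tensors with correct position IDs.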
```python
span_ids, position_ids, attention_mask = tokenize_spans(
    ["Berlin", "Elon Musk", "Tesla"],
    tokenizer,
    prefix_len=prefix_ids.shape[1],
)
```

Returns `(span_ids, position_ids, attention_mask)`, each of shape `[batch, max_span_len]`.
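The padding and position logic can be sketched in plain Python. This assumes right padding, a pad ID of 0, and spans that are already tokenized; this README does not say which side `tokenize_spans` pads on or which pad ID it uses, so `pad_spans` is a hypothetical illustration rather than its implementation.

```python
def pad_spans(span_token_ids, prefix_len, pad_id=0):
    """Right-pad tokenized spans; positions continue from the end of the prefix.

    span_token_ids: one list of token ids per span (already tokenized)
    returns: (span_ids, position_ids, attention_mask), each [batch][max_span_len]
    """
    max_len = max(len(ids) for ids in span_token_ids)
    span_ids, position_ids, attention_mask = [], [], []
    for ids in span_token_ids:
        pad = max_len - len(ids)
        span_ids.append(ids + [pad_id] * pad)
        # Every span starts at position prefix_len, as if it directly followed the prefix.
        position_ids.append([prefix_len + i for i in range(max_len)])
        attention_mask.append([1] * len(ids) + [0] * pad)
    return span_ids, position_ids, attention_mask


span_ids, position_ids, attention_mask = pad_spans([[101], [7, 8, 9], [42, 43]], prefix_len=20)
# span_ids       → [[101, 0, 0], [7, 8, 9], [42, 43, 0]]
# position_ids   → [[20, 21, 22], [20, 21, 22], [20, 21, 22]]
# attention_mask → [[1, 0, 0], [1, 1, 1], [1, 1, 0]]
```

Right padding keeps every real token at the position it would have without padding, and under causal attention real tokens never attend to the padding that follows them, so their logits are unaffected. The mask then mainly serves to locate each span's last real token.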
`get_label_token_ids` resolves label strings to single token IDs. It raises `ValueError` if any label tokenizes to more than one token.

```python
label_token_ids = get_label_token_ids(["0", "1", "2", "3"], tokenizer)
# → [29900, 29896, 29906, 29941]  (example; IDs depend on the tokenizer)
```

Labels must be single tokens. Single-digit labels ("0" to "9") are single tokens in typical LLM tokenizers and are recommended for this reason; multi-digit labels such as "10" can split into several tokens.
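The single-token check can be sketched as follows. `resolve_label_ids` and the toy character-level tokenizer are hypothetical stand-ins; the real `get_label_token_ids` works with a HuggingFace tokenizer and may handle details such as leading spaces differently.

```python
from typing import Callable


def resolve_label_ids(labels: list[str], encode: Callable[[str], list[int]]) -> list[int]:
    """Map each label to exactly one token id, raising ValueError otherwise."""
    ids = []
    for label in labels:
        tokens = encode(label)
        if len(tokens) != 1:
            raise ValueError(f"label {label!r} is {len(tokens)} tokens, expected 1: {tokens}")
        ids.append(tokens[0])
    return ids


# Hypothetical toy tokenizer: one token per character, ids from a fixed table.
TOY_VOCAB = {ch: 100 + i for i, ch in enumerate("0123456789")}


def toy_encode(text: str) -> list[int]:
    return [TOY_VOCAB[ch] for ch in text]


print(resolve_label_ids(["0", "1", "2", "3"], toy_encode))  # [100, 101, 102, 103]
```

With a HuggingFace tokenizer, an `encode` callable could be `lambda s: tokenizer.encode(s, add_special_tokens=False)`. In the toy tokenizer, `"10"` becomes two tokens and raises `ValueError`, the same failure mode multi-digit labels can trigger with real tokenizers.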
`get_entity_scores` extracts next-token scores for the label tokens at the last real token position of each span.

```python
scores = get_entity_scores(logits, label_token_ids, attention_mask)
# → [batch, num_labels]
probabilities = scores.softmax(dim=-1)  # normalized over the label tokens only
predicted_labels = scores.argmax(dim=-1)
```

The prompt format is flexible, but the following structure is a reasonable starting point:
```
<instruction explaining the task>
Entity types: 1. <Type1> 2. <Type2> 3. <Type3> 0. None

Text: <the input document>
Span:
```
Each span is appended directly after `Span:`, so prefix and span read as one continuous prompt: the model sees the full context and is asked to predict which entity number comes next.
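Filling the template programmatically keeps the entity numbering and the label tokens in sync. `build_prompt` is a hypothetical helper; with the arguments below it reproduces the prompt from the first example in this README exactly.

```python
def build_prompt(instruction: str, entity_types: list[str], text: str) -> str:
    """Fill the prompt template; types are numbered from 1 and "0. None" comes last."""
    numbered = " ".join(f"{i}. {name}" for i, name in enumerate(entity_types, start=1))
    return (
        f"{instruction}\n"
        f"Entity types: {numbered} 0. None\n\n"
        f"Text: {text}\n"
        "Span: "
    )


entity_types = ["Person", "Company", "Location"]  # at most 9 keeps labels "0" to "9"
prompt = build_prompt(
    "Identify the entity type of the marked span.",
    entity_types,
    "Apple was founded by Steve Jobs in Cupertino.",
)
# Label token "i" names class labels[i]
label_strings = [str(i) for i in range(len(entity_types) + 1)]  # ["0", "1", "2", "3"]
labels = ["None", *entity_types]  # ["None", "Person", "Company", "Location"]
```

Deriving `label_strings` and `labels` from the same list that builds the prompt avoids the off-by-one mismatches that are easy to introduce when the three are written out by hand.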
Tips:
- Keep label strings as single tokens (single digits work best)
- Include a "None" / "0" label for non-entity spans
- The instruction should match the style the model was instruction-tuned on
- Use `add_special_tokens=True` when tokenizing the prefix so the BOS token is included
| File | Description |
|---|---|
| `model.py` | `UniversalDecoder`: top-level model, `score_spans` entry point |
| `layer.py` | `UniversalDecoderLayer`: attention + MLP, prefix/suffix modes |
| `kv_cache.py` | `DualBufferCache`: separate prefix (shared) and suffix (per-span) KV buffers |
| `kernel.py` | Triton kernel: fused prefix+suffix attention with online softmax |
| `config.py` | `ModelConfig`: parsed from HuggingFace `config.json` |
| `weights.py` | `WeightLoader`: safetensors loading with architecture-specific key mapping |
| `rope.py` | RoPE frequency precomputation and application (NTK and linear scaling) |
| `ops.py` | RMSNorm, LayerNorm |
| `utils.py` | `tokenize_spans`, `get_label_token_ids`, `get_entity_scores` |
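`kernel.py` is described as fused prefix+suffix attention with online softmax. The plain-Python sketch below shows the merge rule such a kernel can rely on, for a single query: attention over the shared prefix keys and over the span's own suffix keys is computed separately in un-normalized form (block max, sum of exponentials, weighted value accumulator), then combined by rescaling both parts to a common max. The result equals softmax attention over the concatenated keys without ever concatenating them. This illustrates the standard online-softmax technique and is not the Triton kernel itself; all function names are hypothetical.

```python
import math


def partial_attention(scores, values):
    """Attention over one block of keys, left un-normalized so blocks can be merged.

    Returns (m, denom, acc): the block's max score m, denom = sum(exp(s - m)),
    and acc = sum(exp(s - m) * v) computed per value dimension.
    """
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    acc = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(len(values[0]))]
    return m, sum(weights), acc


def merge(a, b):
    """Combine two partial results by rescaling both to their common max."""
    (m_a, den_a, acc_a), (m_b, den_b, acc_b) = a, b
    m = max(m_a, m_b)
    scale_a, scale_b = math.exp(m_a - m), math.exp(m_b - m)
    acc = [x * scale_a + y * scale_b for x, y in zip(acc_a, acc_b)]
    return m, den_a * scale_a + den_b * scale_b, acc


def finalize(part):
    _, denom, acc = part
    return [x / denom for x in acc]


# One query's scores against 3 prefix keys and 2 causally visible suffix keys.
prefix_scores, prefix_values = [0.5, 2.0, -1.0], [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
suffix_scores, suffix_values = [1.5, 3.0], [[2.0, 0.0], [0.0, 2.0]]

fused = finalize(merge(partial_attention(prefix_scores, prefix_values),
                       partial_attention(suffix_scores, suffix_values)))
reference = finalize(partial_attention(prefix_scores + suffix_scores,
                                       prefix_values + suffix_values))
assert all(abs(x - y) < 1e-9 for x, y in zip(fused, reference))
```

This is what makes a dual-buffer layout workable: the prefix keys and values can live in one shared buffer, and each span only contributes its own short suffix block to the merge.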