CKeibel/DecodNER

DecodNER

A GPU-accelerated inference engine for Named Entity Recognition (NER) using causal decoder language models with a shared-prefix KV cache.

How it works

NER requires scoring many candidate spans against the same document context. DecodNER exploits this by encoding the context (instruction + entity list + document) once as a shared prefix, then scoring all candidate spans in a single batched pass — without re-encoding the prefix for each span.

┌─────────────────────────────────────────────────────┐
│  PREFIX  (encoded once, KV cache shared across all) │
│  [BOS] <instruction> <entity list> <document text>  │
└──────────────────────┬──────────────────────────────┘
                       │  shared KV cache
          ┌────────────┼────────────┐
          ▼            ▼            ▼
      "Apple"      "Steve Jobs"  "Cupertino"   ← spans (batched)
          │            │            │
          ▼            ▼            ▼
     logit["2"]   logit["1"]   logit["3"]      ← entity label scores
      Company       Person       Location
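To give a rough sense of what encoding the prefix once saves, the following sketch counts how many tokens pass through the model with and without prefix sharing. The function names and the example sizes are illustrative assumptions, not part of DecodNER, and the count ignores the attention cost each span token still pays to read the cached prefix.

```python
def tokens_without_sharing(prefix_len: int, span_lens: list[int]) -> int:
    """Every span re-encodes the full prefix before its own tokens."""
    return sum(prefix_len + n for n in span_lens)


def tokens_with_sharing(prefix_len: int, span_lens: list[int]) -> int:
    """The prefix is encoded once; each span only adds its own tokens."""
    return prefix_len + sum(span_lens)


# Hypothetical workload: a 400-token prefix and 50 candidate spans of 3 tokens each.
span_lens = [3] * 50
naive = tokens_without_sharing(400, span_lens)
shared = tokens_with_sharing(400, span_lens)
print(naive, shared)
```

For this made-up workload, the shared-prefix variant feeds 550 tokens through the model instead of 20,150.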

Each span token attends to the full prefix through the shared KV cache and, causally, to the tokens of its own span up to and including itself. The logit at the last span token position, restricted to the label token IDs, gives the entity classification score.
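The attention pattern described above can be written down as a visibility rule. The sketch below is an illustration in plain Python, not DecodNER's kernel: `visible_keys` is a hypothetical helper that lists which key positions a given span token may attend to, assuming prefix positions are numbered from 0 and span tokens continue from `prefix_len`.

```python
def visible_keys(prefix_len: int, span_pos: int) -> list[int]:
    """Key positions visible to the span token at offset `span_pos` (0-based).

    The token sees every prefix position plus the span tokens up to and
    including itself (causal masking within the span).
    """
    prefix_keys = list(range(prefix_len))
    own_keys = [prefix_len + j for j in range(span_pos + 1)]
    return prefix_keys + own_keys


# A 4-token prefix and a 2-token span such as "Steve Jobs".
first = visible_keys(4, 0)   # first span token
second = visible_keys(4, 1)  # second span token
print(first)
print(second)
```

Under this rule the tokens of one span never appear in another span's visible set, which is what lets the spans be scored as independent rows of one batch.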

Requires Python ≥ 3.12, a CUDA GPU, and PyTorch ≥ 2.11.

Quick start

from transformers import AutoTokenizer
from decodner import UniversalDecoder, tokenize_spans, get_label_token_ids, get_entity_scores

# Load any causal decoder model from HuggingFace Hub
model     = UniversalDecoder.from_hub("meta-llama/Llama-3.2-1B")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")

# Build the shared prefix: instruction + numbered entity list + document
prompt = (
    "Identify the entity type of the marked span.\n"
    "Entity types: 1. Person  2. Company  3. Location  0. None\n\n"
    "Text: Apple was founded by Steve Jobs in Cupertino.\n"
    "Span: "
)
prefix_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=True).input_ids.cuda()

# Tokenize all candidate spans at once (padding handled automatically)
span_ids, position_ids, attention_mask = tokenize_spans(
    ["Apple", "Steve Jobs", "Cupertino", "founded"],
    tokenizer,
    prefix_len=prefix_ids.shape[1],
)

# Encode prefix once, score all spans in a single batched pass
logits = model.score_spans(prefix_ids, span_ids, position_ids)
# → [4, max_span_len, vocab_size]

# Extract scores for label tokens at the last real token of each span
label_token_ids = get_label_token_ids(["0", "1", "2", "3"], tokenizer)
scores = get_entity_scores(logits, label_token_ids, attention_mask)
# → [4, 4]  (4 spans × 4 labels)

predicted_label_idx = scores.argmax(dim=-1)
labels = ["None", "Person", "Company", "Location"]
for span, idx in zip(["Apple", "Steve Jobs", "Cupertino", "founded"], predicted_label_idx):
    print(f"{span:15s} → {labels[idx]}")
# Illustrative output (actual predictions depend on the model and prompt):
# Apple           → Company
# Steve Jobs      → Person
# Cupertino       → Location
# founded         → None

API reference

UniversalDecoder

UniversalDecoder.from_hub(repo_id, dtype=torch.float16)

Load a model directly from HuggingFace Hub. Downloads weights in safetensors format.

Supported architectures: Llama, Mistral, Mixtral, Qwen2.

model = UniversalDecoder.from_hub("meta-llama/Llama-3.2-1B", dtype=torch.float16)

model.score_spans(prefix_ids, span_ids, position_ids=None)

End-to-end convenience method. Encodes the prefix, populates the KV cache, and scores all spans in one call.

logits = model.score_spans(prefix_ids, span_ids, position_ids)
# prefix_ids:   [1, prefix_len]
# span_ids:     [batch, span_len]
# position_ids: [batch, span_len]  — inferred from prefix_len if omitted
# returns:      [batch, span_len, vocab_size]

model.prefill_prefix(prefix_ids, kv_cache) / model.forward_spans(span_ids, position_ids, kv_cache)

Low-level API for reusing a single encoded prefix across multiple forward_spans calls (e.g. scoring spans in mini-batches against a long document).

import torch
from decodner import DualBufferCache

cache = DualBufferCache(
    num_layers     = model.config.num_layers,
    num_kv_heads   = model.config.num_kv_heads,
    head_dim       = model.config.head_dim,
    max_prefix_len = 512,
    max_suffix_len = 32,
    batch_size     = 64,
    dtype          = torch.float16,
    device         = "cuda",
)

model.prefill_prefix(prefix_ids, cache)  # encode once

for batch in span_batches:
    span_ids, position_ids, mask = batch
    logits = model.forward_spans(span_ids, position_ids, cache)
    # process logits ...
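The `span_batches` iterable in the loop above is not defined in this README. One possible way to build it, assuming the candidate spans are held in a plain list and each chunk is then passed through `tokenize_spans`, is a small chunking helper. `chunked` is a hypothetical name, and the chunk size presumably should not exceed the `batch_size` the cache was allocated with.

```python
from collections.abc import Iterator


def chunked(items: list[str], batch_size: int) -> Iterator[list[str]]:
    """Yield consecutive slices of `items` with at most `batch_size` elements."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


candidate_spans = ["Apple", "Steve Jobs", "Cupertino", "founded", "was"]
batches = list(chunked(candidate_spans, batch_size=2))
print(batches)
```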

Utilities

tokenize_spans(spans, tokenizer, prefix_len, device="cuda", add_special_tokens=False)

Tokenize a string or list of strings into padded tensors with correct position IDs.

span_ids, position_ids, attention_mask = tokenize_spans(
    ["Berlin", "Elon Musk", "Tesla"],
    tokenizer,
    prefix_len=prefix_ids.shape[1],
)

Returns (span_ids, position_ids, attention_mask), all [batch, max_span_len].
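As an illustration of what padded outputs with position IDs could look like, the sketch below pads already-tokenized spans and numbers their positions starting at `prefix_len`. It uses plain lists instead of tensors, and the right-padding, the pad ID of 0, and the choice to keep counting positions through the padding are all assumptions; the actual `tokenize_spans` may behave differently.

```python
def pad_spans(token_ids: list[list[int]], prefix_len: int, pad_id: int = 0):
    """Right-pad token ID lists and build position IDs and an attention mask."""
    max_len = max(len(ids) for ids in token_ids)
    span_ids, position_ids, attention_mask = [], [], []
    for ids in token_ids:
        n_pad = max_len - len(ids)
        span_ids.append(ids + [pad_id] * n_pad)
        # Span positions continue where the prefix ends.
        position_ids.append([prefix_len + i for i in range(max_len)])
        # 1 marks a real token, 0 marks padding.
        attention_mask.append([1] * len(ids) + [0] * n_pad)
    return span_ids, position_ids, attention_mask


# Made-up token IDs for two spans of different lengths, after a 10-token prefix.
span_ids, position_ids, attention_mask = pad_spans([[7], [8, 9]], prefix_len=10)
print(span_ids)
print(position_ids)
print(attention_mask)
```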

get_label_token_ids(labels, tokenizer)

Resolve label strings to single token IDs. Raises ValueError if any label tokenizes to more than one token.

label_token_ids = get_label_token_ids(["0", "1", "2", "3"], tokenizer)
# → [29900, 29896, 29906, 29941]  (example; depends on tokenizer)

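Labels must be single tokens. Single-digit numeric labels ("0" to "9") map to one token in most tokenizers and are recommended for this reason; multi-digit or textual labels may be split into several tokens, depending on the tokenizer.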
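The single-token check can be sketched independently of any particular tokenizer. In the example below, `resolve_label_ids` and the toy `toy_encode` function are hypothetical stand-ins rather than DecodNER code; with a HuggingFace tokenizer the encoding step would typically be `tokenizer.encode(label, add_special_tokens=False)`.

```python
from collections.abc import Callable


def resolve_label_ids(labels: list[str], encode: Callable[[str], list[int]]) -> list[int]:
    """Map each label to its token ID, rejecting labels that are not exactly one token."""
    ids = []
    for label in labels:
        tokens = encode(label)
        if len(tokens) != 1:
            raise ValueError(f"label {label!r} maps to {len(tokens)} tokens, expected 1")
        ids.append(tokens[0])
    return ids


# Toy character-level "tokenizer": one token per character (an assumption for the demo).
def toy_encode(text: str) -> list[int]:
    return [ord(ch) for ch in text]


single_digit_ids = resolve_label_ids(["0", "1", "2"], toy_encode)
print(single_digit_ids)

try:
    resolve_label_ids(["10"], toy_encode)
except ValueError as err:
    print(err)
```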

get_entity_scores(logits, label_token_ids, attention_mask=None)

Extract next-token scores for the label tokens at the last real token position of each span.

scores = get_entity_scores(logits, label_token_ids, attention_mask)
# → [batch, num_labels]

probabilities    = scores.softmax(dim=-1)
predicted_labels = scores.argmax(dim=-1)
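For readers who want to see the indexing spelled out, here is a minimal sketch of the extraction step on nested lists rather than tensors. `entity_scores` is a hypothetical re-implementation for illustration, assuming right-padded spans and a 0/1 attention mask; it is not the library's code.

```python
import math


def entity_scores(logits, label_token_ids, attention_mask):
    """Pick the label logits at the last real token of each span.

    logits:         [batch][span_len][vocab_size] nested lists
    attention_mask: [batch][span_len] with 1 for real tokens, 0 for padding
    """
    scores = []
    for span_logits, mask in zip(logits, attention_mask):
        last = sum(mask) - 1  # index of the last real token (right padding assumed)
        scores.append([span_logits[last][tok] for tok in label_token_ids])
    return scores


def softmax(row):
    """Numerically stable softmax over one row of scores."""
    peak = max(row)
    exps = [math.exp(x - peak) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


# Two spans, span_len 2, vocabulary of 4 tokens; the first span has one padding slot.
logits = [
    [[0.1, 2.0, 0.3, 0.0], [9.0, 9.0, 9.0, 9.0]],
    [[0.0, 0.0, 0.0, 0.0], [0.5, 0.2, 3.0, 0.1]],
]
mask = [[1, 0], [1, 1]]
scores = entity_scores(logits, label_token_ids=[1, 2], attention_mask=mask)
probabilities = [softmax(row) for row in scores]
predicted = [row.index(max(row)) for row in scores]
print(scores)
print(predicted)
```

Note that the padded position of the first span (the row of 9.0 values) is ignored, because the mask places the last real token at index 0.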

Prompt design

The prompt format is flexible; the following structure is a reasonable starting point:

<instruction explaining the task>
Entity types: 1. <Type1>  2. <Type2>  3. <Type3>  0. None

Text: <the input document>
Span:

The document and span together form a natural continuation — the model sees the full context and is asked to predict which entity number follows the span.

Tips:

  • Keep label strings as single tokens (numbers work best)
  • Include a "None" / "0" label for non-entity spans
  • The instruction should match the style the model was instruction-tuned on
  • Use add_special_tokens=True when tokenizing the prefix so the BOS token is included
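A small helper can assemble a prompt in this shape from a list of entity types. `build_prefix` is a hypothetical function that is not part of DecodNER; the numbering follows the template above, with 0 reserved for None.

```python
def build_prefix(instruction: str, entity_types: list[str], document: str) -> str:
    """Assemble instruction, numbered entity list, and document into one prefix string."""
    numbered = [f"{i}. {name}" for i, name in enumerate(entity_types, start=1)]
    numbered.append("0. None")  # label for non-entity spans
    return (
        f"{instruction}\n"
        f"Entity types: {'  '.join(numbered)}\n\n"
        f"Text: {document}\n"
        "Span: "
    )


prompt = build_prefix(
    "Identify the entity type of the marked span.",
    ["Person", "Company", "Location"],
    "Apple was founded by Steve Jobs in Cupertino.",
)
print(prompt)
```

With these arguments the helper reproduces the prompt string used in the quick start. With more than nine entity types the labels become multi-digit, which may no longer tokenize to a single token.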

Architecture

File          Description
model.py      UniversalDecoder: top-level model, score_spans entry point
layer.py      UniversalDecoderLayer: attention + MLP, prefix/suffix modes
kv_cache.py   DualBufferCache: separate prefix (shared) and suffix (per-span) KV buffers
kernel.py     Triton kernel: fused prefix+suffix attention with online softmax
config.py     ModelConfig: parsed from HuggingFace config.json
weights.py    WeightLoader: safetensors loading with architecture-specific key mapping
rope.py       RoPE frequency precomputation and application (NTK and linear scaling)
ops.py        RMSNorm, LayerNorm
utils.py      tokenize_spans, get_label_token_ids, get_entity_scores
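The "online softmax" mentioned for kernel.py is a general technique for computing a softmax-weighted sum block by block while tracking a running maximum and normaliser. The sketch below shows that technique in plain Python for a single query with scalar values, processing a prefix block and then a suffix block. It illustrates the idea only; it is not the Triton kernel, and all names are hypothetical.

```python
import math


def online_softmax_attention(blocks):
    """Softmax-weighted average over several (scores, values) blocks in one pass.

    Keeps a running max `m`, running normaliser `l`, and running weighted
    sum `acc`, rescaling the old state whenever a new block raises the max.
    """
    m, l, acc = float("-inf"), 0.0, 0.0
    for scores, values in blocks:
        m_new = max(m, max(scores))
        scale = math.exp(m - m_new)  # exp(-inf) == 0.0 on the first block
        weights = [math.exp(s - m_new) for s in scores]
        l = l * scale + sum(weights)
        acc = acc * scale + sum(w * v for w, v in zip(weights, values))
        m = m_new
    return acc / l


def reference_attention(scores, values):
    """Ordinary softmax-weighted average over the full key sequence."""
    peak = max(scores)
    weights = [math.exp(s - peak) for s in scores]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)


prefix_block = ([0.5, 2.0, -1.0], [1.0, 2.0, 3.0])  # (attention scores, values)
suffix_block = ([3.0, 0.0], [4.0, 5.0])
fused = online_softmax_attention([prefix_block, suffix_block])
direct = reference_attention(
    prefix_block[0] + suffix_block[0], prefix_block[1] + suffix_block[1]
)
print(abs(fused - direct) < 1e-9)
```

Processing the prefix and suffix keys as separate blocks gives the same result as attending over their concatenation, which is why the two KV buffers never need to be copied into one contiguous tensor.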

About

A decoder-based Named Entity Recognition (NER) framework
