
[Bug/Question] Numerical inconsistency in model representations among inferences #7

@yusowa0716

Description:
I am observing a discrepancy in the output representations (token embeddings) when running inference on the exact same sequence in two different execution contexts on an NVIDIA A100.

The Problem:
I am comparing the representations of a specific sequence $S$ in two scenarios:

  • Scenario A (Loop): I run inference on $N$ sequences sequentially in a loop, then extract the representation for sequence $S$.
  • Scenario B (Single Run): I run the inference script on sequence $S$ alone.

In both scenarios, I make sure the batch size is 1. Despite identical inputs and model weights, the resulting tensors are not identical.
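For context, identical inputs and weights do not by themselves guarantee bitwise-identical outputs. Floating-point addition is not associative, so if the two runs accumulate values in a different order (for example, because a different GPU kernel, tile size or padded shape is selected), the last bits of the result can differ. The sketch below illustrates this with plain Python floats (float64) rather than GPU tensors; `sum_left_to_right` is a hypothetical helper standing in for a sequential reduction.

```python
import math


def sum_left_to_right(values):
    """Accumulate values strictly in the given order, like a sequential reduction."""
    total = 0.0
    for v in values:
        total += v
    return total


values = [0.1, 0.2, 0.3]
forward = sum_left_to_right(values)             # (0.1 + 0.2) + 0.3
backward = sum_left_to_right(reversed(values))  # (0.3 + 0.2) + 0.1

print(forward, backward, forward == backward)  # 0.6000000000000001 0.6 False
print(math.isclose(forward, backward))         # True: equal up to rounding error
```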

Comparison Results:

[Image: comparison results]
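To judge whether differences like these are ordinary floating-point noise or a real bug, it helps to report the maximum absolute difference and apply an `allclose`-style tolerance instead of testing for exact equality. The sketch below is a hypothetical helper (`compare_embeddings`, not part of E1 or PyTorch). It is written in plain Python over nested lists so that it runs anywhere, and it uses the same elementwise criterion as `torch.allclose`: |a - b| <= atol + rtol * |b|. With real tensors, `(a - b).abs().max()` and `torch.allclose(a, b, rtol=..., atol=...)` do the same job.

```python
def compare_embeddings(a, b, rtol=1e-5, atol=1e-8):
    """Compare two (seq_len x dim) matrices given as nested lists.

    Returns the maximum absolute difference and whether every element
    satisfies |a - b| <= atol + rtol * |b| (the torch.allclose rule).
    """
    if len(a) != len(b) or any(len(ra) != len(rb) for ra, rb in zip(a, b)):
        raise ValueError("shape mismatch")
    max_abs_diff = 0.0
    all_close = True
    for row_a, row_b in zip(a, b):
        for x, y in zip(row_a, row_b):
            diff = abs(x - y)
            max_abs_diff = max(max_abs_diff, diff)
            if diff > atol + rtol * abs(y):
                all_close = False
    return max_abs_diff, all_close


# Toy example: two 2x3 "embeddings" that differ only in one element, by about 1e-7.
run_a = [[0.5, -1.25, 2.0], [0.125, 3.0, -0.75]]
run_b = [[0.5, -1.25, 2.0000001], [0.125, 3.0, -0.75]]
diff, ok = compare_embeddings(run_a, run_b)
print(f"max |a - b| = {diff:.2e}, allclose = {ok}")  # max |a - b| = 1.00e-07, allclose = True
```

The default `rtol`/`atol` values mirror the `torch.allclose` defaults; for float32, bfloat16 or TF32 computations, somewhat looser tolerances may be more realistic.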

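Minimal Code:

```python
import os

import torch

# E1ForMaskedLM, E1Predictor and E1_CONFIG come from the E1 package;
# their import lines were not included in the original snippet.


def load_e1_model(model_name):
    """Load an E1 checkpoint in eval mode and move it to the GPU if available."""
    try:
        model = E1ForMaskedLM.from_pretrained(E1_CONFIG[model_name])
    except Exception as err:  # a bare except would also swallow KeyboardInterrupt
        raise ValueError(f"Model {model_name} not found") from err
    model.eval()  # disable dropout for inference

    if torch.cuda.is_available():
        model = model.cuda()

    return model


def compute_E1_embeddings(
    model,
    labels,
    sequences,
    save_dir,
    max_batch_tokens=16384,
):
    """Compute per-token embeddings and save each one to <save_dir>/<label>.pt."""
    predictor = E1Predictor(
        model=model,
        max_batch_tokens=max_batch_tokens,
        fields_to_save=["token_embeddings"],
        use_cache=False,
    )

    embeddings = {}

    for prediction in predictor.predict(
        sequences=sequences, sequence_ids=labels, context_seqs=None
    ):
        label = prediction["id"]
        save_path = os.path.join(save_dir, label + ".pt")
        # Skip only this sequence if its embedding was already saved in an earlier run.
        if os.path.exists(save_path):
            continue
        token_embeddings = prediction["token_embeddings"]  # (sequence length, embedding dim)
        embeddings[label] = token_embeddings.cpu().clone()
        torch.save(embeddings[label], save_path)
    return embeddings
```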
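One thing worth ruling out is batching inside the predictor. It is constructed with `max_batch_tokens=16384`, and the parameter name suggests that `E1Predictor` may group several sequences into one padded batch up to a token budget. This is an assumption based on the name only, not something checked against the E1 source. If it does, then in Scenario A sequence $S$ could share a batch with other sequences and be padded to a longer length, while in Scenario B it runs alone. The two runs would then use different tensor shapes, and possibly different GPU kernels, even though each call appears to have batch size 1. The greedy batcher below is hypothetical (`greedy_token_batches` is illustrative, not E1's implementation) and shows how the same sequence can land in batches of different size and padded length.

```python
from typing import List, Sequence


def greedy_token_batches(lengths: Sequence[int], max_batch_tokens: int) -> List[List[int]]:
    """Group sequence indices in order until the padded batch would exceed the budget.

    Padded cost of a batch = (number of sequences) * (longest sequence in it).
    Hypothetical illustration only; not the E1Predictor algorithm.
    """
    batches: List[List[int]] = []
    current: List[int] = []
    longest = 0
    for idx, length in enumerate(lengths):
        new_longest = max(longest, length)
        if current and (len(current) + 1) * new_longest > max_batch_tokens:
            batches.append(current)  # close the batch before it overflows
            current, new_longest = [], length
        current.append(idx)
        longest = new_longest
    if current:
        batches.append(current)
    return batches


# Scenario A: S (length 300, index 2) is inferred together with other sequences,
# so under this scheme it is padded to 450 tokens in a batch of 4.
loop_lengths = [120, 450, 300, 80]
print(greedy_token_batches(loop_lengths, max_batch_tokens=16384))  # [[0, 1, 2, 3]]

# Scenario B: S alone, unpadded, in a batch of 1.
print(greedy_token_batches([300], max_batch_tokens=16384))  # [[0]]
```

If the predictor does batch along these lines, running Scenario A with one `predict` call per sequence would be a quick way to check whether batching accounts for the difference.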
