From 6227ef3c929daab2f588b0adc8022d5c7fc215b1 Mon Sep 17 00:00:00 2001 From: Jonathan Olsson Date: Tue, 19 May 2026 11:03:43 +0200 Subject: [PATCH 1/3] Kronos, Deepcell, Eva implementation --- celldownload.py | 48 ++ src/data/DeepCellData.py | 86 +++ src/data/EvaData.py | 69 +++ src/data/KRONOSData.py | 69 +++ src/models/DeepCellModel.py | 292 ++++++++++ src/models/EvaModel.py | 427 +++++++++++++++ src/models/KRONOSModel.py | 506 ++++++++++++++++++ src/models/__init__.py | 52 ++ src/run/CellContrastEmbed.py | 143 +++-- src/utils/deepcell_kit/__init__.py | 0 src/utils/deepcell_kit/config.py | 115 ++++ .../deepcell_kit/config/channel_mapping.yaml | 83 +++ .../deepcell_kit/config/core_celltypes.yaml | 49 ++ .../deepcell_kit/config/master_channels.yaml | 177 ++++++ .../tissue_celltype_mapping_merged.yaml | 420 +++++++++++++++ src/utils/deepcell_kit/image_funcs.py | 455 ++++++++++++++++ src/utils/deepcell_kit/utils.py | 105 ++++ src/utils/download_utils/__init__.py | 0 src/utils/download_utils/dct_download.py | 195 +++++++ src/utils/download_utils/kr_download.py | 45 ++ src/utils/eva_kit/__init__.py | 0 src/utils/eva_kit/config.yaml | 27 + src/utils/eva_kit/constant.py | 209 ++++++++ src/utils/eva_kit/global_properties.py | 24 + src/utils/eva_kit/image_funcs.py | 263 +++++++++ src/utils/eva_kit/layers.py | 286 ++++++++++ src/utils/eva_kit/masking.py | 177 ++++++ src/utils/eva_kit/pos_embed.py | 248 +++++++++ src/utils/kronos_kit/__init__.py | 0 src/utils/kronos_kit/attention.py | 90 ++++ src/utils/kronos_kit/block.py | 265 +++++++++ src/utils/kronos_kit/dino_head.py | 58 ++ src/utils/kronos_kit/drop_path.py | 34 ++ src/utils/kronos_kit/image_funcs.py | 210 ++++++++ src/utils/kronos_kit/layer_scale.py | 27 + src/utils/kronos_kit/marker_metadata.py | 141 +++++ src/utils/kronos_kit/mlp.py | 40 ++ src/utils/kronos_kit/patch_embed.py | 100 ++++ src/utils/kronos_kit/swiglu_ffn.py | 72 +++ 39 files changed, 5569 insertions(+), 38 deletions(-) create mode 100644 celldownload.py 
create mode 100644 src/data/DeepCellData.py create mode 100644 src/data/EvaData.py create mode 100644 src/data/KRONOSData.py create mode 100644 src/models/DeepCellModel.py create mode 100644 src/models/EvaModel.py create mode 100644 src/models/KRONOSModel.py create mode 100644 src/utils/deepcell_kit/__init__.py create mode 100644 src/utils/deepcell_kit/config.py create mode 100644 src/utils/deepcell_kit/config/channel_mapping.yaml create mode 100644 src/utils/deepcell_kit/config/core_celltypes.yaml create mode 100644 src/utils/deepcell_kit/config/master_channels.yaml create mode 100644 src/utils/deepcell_kit/config/tissue_celltype_mapping_merged.yaml create mode 100644 src/utils/deepcell_kit/image_funcs.py create mode 100644 src/utils/deepcell_kit/utils.py create mode 100644 src/utils/download_utils/__init__.py create mode 100644 src/utils/download_utils/dct_download.py create mode 100644 src/utils/download_utils/kr_download.py create mode 100644 src/utils/eva_kit/__init__.py create mode 100644 src/utils/eva_kit/config.yaml create mode 100644 src/utils/eva_kit/constant.py create mode 100644 src/utils/eva_kit/global_properties.py create mode 100644 src/utils/eva_kit/image_funcs.py create mode 100644 src/utils/eva_kit/layers.py create mode 100644 src/utils/eva_kit/masking.py create mode 100644 src/utils/eva_kit/pos_embed.py create mode 100644 src/utils/kronos_kit/__init__.py create mode 100644 src/utils/kronos_kit/attention.py create mode 100644 src/utils/kronos_kit/block.py create mode 100644 src/utils/kronos_kit/dino_head.py create mode 100644 src/utils/kronos_kit/drop_path.py create mode 100644 src/utils/kronos_kit/image_funcs.py create mode 100644 src/utils/kronos_kit/layer_scale.py create mode 100644 src/utils/kronos_kit/marker_metadata.py create mode 100644 src/utils/kronos_kit/mlp.py create mode 100644 src/utils/kronos_kit/patch_embed.py create mode 100644 src/utils/kronos_kit/swiglu_ffn.py diff --git a/celldownload.py b/celldownload.py new file mode 100644 
class PatchDataset():
    """
    Per-ROI cell-crop files plus the logic to embed them with a DeepCell model.

    The constructor lists every ``*_cells.npy`` file found in
    ``<cwd>/<root_dir>/raw/<raw_subset_dir>/<split>``; ``save_embed_data``
    writes one ``*_embed.pt`` tensor next to each of those files.
    """
    def __init__(self,
                 root_dir='data/raw',
                 raw_subset_dir='',
                 split='train',
                 crop_factor=0.5,
                 n_clusters=1,
                 save_embed_data=False,
                 **args):
        """
        root_dir (str): data directory, relative to the current working directory
        raw_subset_dir (str): sub-directory of ``<root_dir>/raw`` holding this data set
        split (str): 'train' or 'test'
        crop_factor, n_clusters, save_embed_data, **args: accepted for call
            compatibility; not used by this class
        """
        assert split in ['train', 'test'], f'split must be either train or test, but is {split}'

        # NOTE(review): with the default root_dir='data/raw' this resolves to
        # <cwd>/data/raw/raw/<subset> -- confirm that the doubled 'raw' is intended.
        self.work_dir = os.path.join(os.getcwd(), root_dir, 'raw', raw_subset_dir)
        self.img_dir = os.path.join(self.work_dir, split)
        self.cells_path = [
            os.path.join(self.img_dir, p)
            for p in os.listdir(self.img_dir)
            if p.lower().endswith('_cells.npy')
        ]
        self.dct_config = DCTConfig()

    @staticmethod
    def _embed_path(cells_path):
        """Return the output path for a cells file: same stem with '_embed.pt' appended."""
        # splitext only strips the final extension. The previous
        # ``path.split('.')[0]`` cut the path at the first dot anywhere in it
        # (e.g. a 'run.v2' directory), so all ROIs could end up in one file.
        return os.path.splitext(cells_path)[0] + '_embed.pt'

    def create_attn_mask(self, sample, max_channels):
        """
        Build the channel padding mask for a stack of cells.

        sample: array/tensor whose first two axes are (cells, channels)
        max_channels (int): number of channel slots after padding

        Returns a bool array of shape (cells, max_channels); True marks a
        padded (non-existent) channel slot.
        """
        # True = padding
        # https://pytorch.org/docs/stable/generated/torch.ao.nn.quantizable.MultiheadAttention.html#torch.ao.nn.quantizable.MultiheadAttention.forward
        mask = np.full((sample.shape[0], max_channels), True)
        mask[:, 0 : sample.shape[1]] = False
        return mask

    def pad_images(self, sample, max_channels):
        """
        Pad the channel axis (axis 1) of a 5-D stack up to ``max_channels``.

        Added channel slots are filled with -1.0. np.pad raises ValueError if
        the stack already has more than ``max_channels`` channels.
        """
        paddings = -1.0  # retrieved as a constant from repo (?)
        sample = np.asarray(sample)
        n_missing = max_channels - sample.shape[1]
        return np.pad(
            sample,
            ((0, 0), (0, n_missing), (0, 0), (0, 0), (0, 0)),
            mode="constant",
            constant_values=paddings,
        )

    def save_embed_data(self, model, device='cpu', batch_size=256):
        """
        Save model representations of all cells per ROI.

        model (torch.Module): deepcell model; must expose ``embed_size`` and be
            callable as ``model(images, channel_idx, padding_mask)``
        device (str): device to operate on
        batch_size (int): Number of cells to extract representations from at once

        Raises FileNotFoundError if ``channel_idx.npy`` is missing from the
        data-set directory.
        """
        # np.load raises a FileNotFoundError naming the path when the file is
        # absent, instead of the bare IndexError of an empty listing.
        ch_idx = torch.from_numpy(np.load(os.path.join(self.work_dir, 'channel_idx.npy')))
        max_channels = self.dct_config.MAX_NUM_CHANNELS

        model.eval()
        with torch.no_grad():
            ch_idx_dev = ch_idx.to(device)
            for path in tqdm(self.cells_path, desc='Save embeddings'):
                cells = np.load(path)  # (B, C, 3, H, W)

                attn_mask = torch.as_tensor(
                    self.create_attn_mask(cells, max_channels), dtype=torch.bool
                )  # (B, C_max)
                sample = torch.as_tensor(
                    self.pad_images(cells, max_channels), dtype=torch.float32
                )  # (B, C_max, 3, H, W)
                del cells

                n_cells = sample.shape[0]
                embed = torch.empty((n_cells, model.embed_size), dtype=torch.float32)

                for idx_start in range(0, n_cells, batch_size):
                    idx_end = min(idx_start + batch_size, n_cells)
                    embed[idx_start:idx_end] = model(
                        sample[idx_start:idx_end].to(device),
                        ch_idx_dev,
                        attn_mask[idx_start:idx_end].to(device),
                    ).to('cpu')

                torch.save(embed, self._embed_path(path))
                del sample
                del attn_mask
                del embed
+ """ + def __init__(self, + conf, + root_dir='data/raw', + raw_subset_dir='', + split='train', + **args): + + assert split in ['train', 'test'], f'split must be either train or test, but is {split}' + + self.work_dir = os.path.join(os.getcwd(), root_dir, 'raw', raw_subset_dir) + self.img_dir = os.path.join(self.work_dir,split) + self.cells_path = [os.path.join(self.img_dir, p) for p in os.listdir(self.img_dir) if p.lower().endswith('_cells.npy')] + self.conf = conf + + channel_names = args['channel_names'] + channel_mask = [True]*len(channel_names) + for idx, channel_name in enumerate(channel_names): + if channel_name not in marker_to_gene.keys(): + print(f"WARNING! {channel_name} is not in GenePT embeddings and will be masked!") + channel_mask[idx] = False + self.channel_names = np.array(channel_names)[channel_mask] + + def save_embed_data(self, model, device='cpu', batch_size=256): + """ + Save model representations of all cells per ROI. + + model (torch.Module): KRONOS model + device (str): device to operate on + batch_size (int): Number of cells to extract representations from at once + """ + + model.eval() + with torch.no_grad(): + for path in tqdm(self.cells_path, desc='Save embeddings'): + if not os.path.exists(os.path.join(path,path.split('.')[0]+'_embed.pt')): + sample = torch.from_numpy(np.load(path)) # (B, C, H, W) + embed = torch.empty((sample.shape[0],self.conf.pm.dim), dtype=torch.float32) + bms = [self.channel_names.copy() for _ in range(sample.shape[0])] + + num_batches = (sample.shape[0]+batch_size-1) // batch_size + for batch_idx in range(num_batches): + if batch_idx < num_batches - 1: + idx_start = batch_idx*batch_size + idx_end = batch_idx*batch_size+batch_size + else: + idx_start = batch_idx*batch_size + idx_end = len(embed) + + image_out, _ = model.model.forward_encoder(sample[idx_start:idx_end].to(device), bms[idx_start:idx_end]) + image_cls = image_out[:,:,0,:] + image_cls = image_cls.squeeze(1) + batch_size = image_cls.size(0) + feat = 
class KRONOSDataset():
    """
    Per-ROI cell-crop files plus the logic to embed them with a KRONOS model.

    The constructor lists every ``*_cells.npy`` file found in
    ``<cwd>/<root_dir>/raw/<raw_subset_dir>/<split>`` and loads the marker ids
    from ``<cwd>/data/marker_info_with_metadata.csv``.
    """
    def __init__(self,
                 root_dir='data/raw',
                 raw_subset_dir='',
                 split='train',
                 **args):
        """
        root_dir (str): data directory, relative to the current working directory
        raw_subset_dir (str): sub-directory of ``<root_dir>/raw`` holding this data set
        split (str): 'train' or 'test'
        **args: accepted for call compatibility; not used by this class

        Raises FileNotFoundError if marker_info_with_metadata.csv is missing.
        """
        assert split in ['train', 'test'], f'split must be either train or test, but is {split}'

        self.work_dir = os.path.join(os.getcwd(), root_dir, 'raw', raw_subset_dir)
        self.img_dir = os.path.join(self.work_dir, split)
        self.cells_path = [
            os.path.join(self.img_dir, p)
            for p in os.listdir(self.img_dir)
            if p.lower().endswith('_cells.npy')
        ]

        marker_csv = os.path.join(os.getcwd(), 'data', 'marker_info_with_metadata.csv')
        try:
            marker_df = pd.read_csv(marker_csv)
        except FileNotFoundError as err:
            # Previously a warning was printed and execution fell through to an
            # unbound ``marker_df``; fail here with an actionable message instead.
            raise FileNotFoundError(
                f'{marker_csv} not found, run image_preprocess for KRONOS first'
            ) from err
        self.marker_ids = marker_df['marker_id']

    @staticmethod
    def _embed_path(cells_path):
        """Return the output path for a cells file: same stem with '_embed.pt' appended."""
        # splitext only strips the final extension. The previous
        # ``path.split('.')[0]`` cut the path at the first dot anywhere in it.
        return os.path.splitext(cells_path)[0] + '_embed.pt'

    def save_embed_data(self, model, device='cpu', batch_size=256):
        """
        Save model representations of all cells per ROI.

        model (torch.Module): KRONOS model; must expose ``embed_dim`` and return
            a 3-tuple whose first element is the per-cell feature tensor
        device (str): device to operate on
        batch_size (int): Number of cells to extract representations from at once
        """
        # One row of marker ids; it is tiled per batch rather than once for
        # every cell of the ROI, which keeps the id tensor small.
        marker_row = torch.as_tensor(np.asarray(self.marker_ids)).to(device, dtype=torch.int64)

        model.eval()
        with torch.no_grad():
            for path in tqdm(self.cells_path, desc='Save embeddings'):
                sample = torch.from_numpy(np.load(path))
                n_cells = sample.shape[0]
                embed = torch.empty((n_cells, model.embed_dim), dtype=torch.float32)

                for idx_start in range(0, n_cells, batch_size):
                    idx_end = min(idx_start + batch_size, n_cells)
                    patch_features, _, _ = model(
                        sample[idx_start:idx_end].to(device, dtype=torch.float32),
                        marker_ids=marker_row.unsqueeze(0).repeat(idx_end - idx_start, 1),
                    )
                    embed[idx_start:idx_end] = patch_features.to('cpu')
                    del patch_features

                torch.save(embed, self._embed_path(path))
                del sample
                del embed
class GradReverse(torch.autograd.Function):
    """Identity in the forward pass; negates the gradient in the backward pass."""

    @staticmethod
    def forward(ctx, x):
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.neg()


def grad_reverse(x):
    """Apply GradReverse to ``x``."""
    return GradReverse.apply(x)


class CellTypeClassificationHead(nn.Module):
    """Three-layer MLP (SiLU + dropout) mapping n_filters features to n_celltypes logits."""
    def __init__(self, n_filters, n_celltypes, dropout_rate=0.1):
        super().__init__()
        self.dense1 = nn.Linear(n_filters, n_filters)
        self.dense2 = nn.Linear(n_filters, n_filters // 2)
        self.dense3 = nn.Linear(n_filters // 2, n_celltypes)
        self.silu = nn.SiLU()
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        out = self.dropout(self.silu(self.dense1(x)))
        out = self.dropout(self.silu(self.dense2(out)))
        return self.dense3(out)


class DomainClassificationHead(nn.Module):
    """For domain adaptation, reverse the gradient."""
    def __init__(self, n_filters, n_domains, dropout_rate=0.1):
        super().__init__()
        self.dense1 = nn.Linear(n_filters, n_filters)
        self.dense2 = nn.Linear(n_filters, n_filters // 2)
        self.dense3 = nn.Linear(n_filters // 2, n_domains)
        self.silu = nn.SiLU()
        self.dropout = nn.Dropout(dropout_rate)
        # These are BatchNorm1d layers despite the attribute names; the names
        # (and the unused layer_norm3) are kept so the module's state-dict
        # keys stay unchanged.
        self.layer_norm1 = nn.BatchNorm1d(n_filters)
        self.layer_norm2 = nn.BatchNorm1d(n_filters // 2)
        self.layer_norm3 = nn.BatchNorm1d(n_domains)

    def forward(self, x):
        out = grad_reverse(x)
        out = self.dropout(self.silu(self.layer_norm1(self.dense1(out))))
        out = self.dropout(self.silu(self.layer_norm2(self.dense2(out))))
        return self.dense3(out)


class ConvBlock(nn.Module):
    """ Simple Convolutional block for feature extraction """
    def __init__(self, n_filters):
        super().__init__()
        self.n_filters = n_filters

        def conv_bn_act(c_in, c_out, stride):
            return [
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=1),
                nn.BatchNorm2d(c_out),
                nn.SiLU(),
            ]

        # Five (stride-2, stride-1) pairs with growing width, then one final
        # stride-2 conv: six halvings in total. Module order inside the
        # Sequential is the same as the hand-written list it replaces.
        modules = []
        c_in = 3
        for width in (n_filters // 16, n_filters // 8, n_filters // 4, n_filters // 2, n_filters):
            modules += conv_bn_act(c_in, width, stride=2)
            modules += conv_bn_act(width, width, stride=1)
            c_in = width
        modules += conv_bn_act(n_filters, n_filters, stride=2)
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        """x: (B, C, 3, H, W) -> (B, C, n_filters); H and W must reduce to 1x1."""
        # merge first two dimensions (B, C, 3, H, W) -> (B * C, 3, H, W)
        out = x.reshape(-1, 3, x.shape[-2], x.shape[-1])
        out = self.layers(out)

        # reshape back to original shape
        assert out.shape[-2] == out.shape[-1] == 1  # spatial dimensions are 1
        return out.reshape(x.shape[0], x.shape[1], self.n_filters)


class MarkerNameEmbeddingLayer(nn.Module):
    """Load pre-trained embeddings for marker names, then apply a linear layer."""
    def __init__(self, n_filters, marker_embeddings):
        super().__init__()

        # Cast to float32 so float64 numpy embeddings do not produce a float64
        # table feeding the float32 linear layer below.
        marker_embeddings = torch.as_tensor(marker_embeddings, dtype=torch.float32)
        embeddings = torch.cat(
            [
                marker_embeddings.new_zeros(1, marker_embeddings.shape[1]),  # padding
                marker_embeddings,
            ],
            dim=0,
        )

        self.embed_layer = nn.Embedding.from_pretrained(
            embeddings, freeze=True, padding_idx=0
        )
        self.dense = nn.Linear(embeddings.shape[1], n_filters)

    def forward(self, x):
        """x: integer marker indices; index -1 maps to the all-zero padding row."""
        out = x + 1  # shift by 1 to account for padding
        out = self.embed_layer(out)
        return self.dense(out)


class CellTypeDataEncoder(nn.Module):
    """ Encode cell type data, including marker names and images. """
    def __init__(self, n_filters, n_heads, marker_embeddings, img_feature_extractor="conv",
                 embed_size=None):
        """
        n_filters (int): feature width of every token and of the returned embedding
        n_heads (int): attention heads in the transformer encoder layers
        marker_embeddings: (n_markers, dim) pre-trained marker-name embeddings
        img_feature_extractor: accepted for call compatibility; a ConvBlock is always used
        embed_size (int | None): value exposed as ``self.embed_size``. Defaults to
            ``n_filters``, which is the width of the tensor ``forward`` returns.
        """
        super().__init__()
        self.n_heads = n_heads
        self.embed_size = n_filters if embed_size is None else embed_size

        # Define marker name embedding layer
        self.marker_embedder = MarkerNameEmbeddingLayer(n_filters, marker_embeddings)

        # Define CLS token
        self.cls_token = nn.Parameter(torch.randn(1, 1, n_filters))

        # Define blocks
        self.img_feature_extractor = ConvBlock(n_filters)

        self.transformer_blocks = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=n_filters, nhead=n_heads, dim_feedforward=n_filters*2, batch_first=True
            ),
            num_layers=5,
        )
        self.cls_single_attention = nn.MultiheadAttention(n_filters, num_heads=1, dropout=0.0, batch_first=True)

        self.marker_positivity_head = nn.Linear(n_filters, 1)

    def forward(self, inputs_app, inputs_ch_names, inputs_ch_padding_masks, return_attention=False):
        """
        inputs_app: (B, C, 3, H, W)
        inputs_ch_names: integer marker indices, broadcastable against (B, C)
        inputs_ch_padding_mask: (B, C), True=ignore
        return_attention (bool): when True also return the CLS token's attention
            weights over the C channels

        Returns the CLS embedding (B, n_filters), or
        ``(embedding, attention)`` with attention (B, C) when return_attention is True.
        """
        pad_mask = inputs_ch_padding_masks.bool()
        if pad_mask.shape[1] < 2:
            raise ValueError("at least two channel slots are required to build the CLS mask")
        # Same result as reflect-padding the mask by one on the left (the CLS
        # slot receives the flag of channel 1), but written with cat so it does
        # not depend on reflect-pad kernel support for integer tensors.
        aug_inputs_ch_padding_masks = torch.cat([pad_mask[:, 1:2], pad_mask], dim=1)  # (B, C+1)

        # Apply convolutions
        x = self.img_feature_extractor(inputs_app)  # (B, C, n_filters)

        # Create marker name embeddings
        marker_embeddings = self.marker_embedder(inputs_ch_names)

        if self.training:
            # Add noise to marker name embeddings
            marker_embeddings = marker_embeddings + torch.randn_like(marker_embeddings) * 0.005
        # Normalize marker embeddings
        # NOTE(review): the collapsed source does not show whether this line sat
        # inside the `if self.training` branch; it is applied in both modes here
        # so train and eval inputs share a scale -- confirm against upstream.
        marker_embeddings = marker_embeddings / marker_embeddings.norm(dim=-1, keepdim=True)

        x = x + marker_embeddings

        # Apply transformer (w/o CLS token)
        x = self.transformer_blocks(x, src_key_padding_mask=inputs_ch_padding_masks)

        # Add CLS token
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)  # (B, C+1, n_filters)

        # Apply Single attention layer
        x, attention = self.cls_single_attention(
            x, x, x,
            key_padding_mask=aug_inputs_ch_padding_masks,
            need_weights=True,
            average_attn_weights=False,
        )

        # Take the CLS token embedding out
        cls_token_embedding = x[:, 0, :]  # (B, n_filters)

        if return_attention:
            # head 0, CLS query, all channel keys
            return cls_token_embedding, attention[:, 0, 0, 1:]
        return cls_token_embedding


class CellTypeCLIPModel(nn.Module):
    """ Apply contrastive learning to data against cell type names. """
    def __init__(self, n_filters, embedding_dim, ct_embeddings, marker_embeddings, n_heads, n_celltypes, n_domains, img_feature_extractor="conv"):
        """
        n_filters (int): encoder feature width
        embedding_dim (int): width of the cell-type text embeddings
        ct_embeddings: (n_celltypes, embedding_dim) cell-type name embeddings
        marker_embeddings: (n_markers, dim) marker-name embeddings
        n_heads (int): attention heads in the encoder
        n_celltypes, n_domains: accepted for call compatibility; the encoder in
            this file has no classification heads and does not take them
        img_feature_extractor: forwarded to the encoder
        """
        super().__init__()

        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.ct_embedding = nn.Embedding.from_pretrained(
            torch.as_tensor(ct_embeddings, dtype=torch.float32), freeze=True
        )

        # The encoder does not accept n_celltypes / n_domains; passing them
        # raised TypeError at construction time.
        self.image_encoder = CellTypeDataEncoder(
            n_filters=n_filters,
            n_heads=n_heads,
            marker_embeddings=marker_embeddings,
            img_feature_extractor=img_feature_extractor,
        )

        self.image_adaptor = nn.Sequential(
            nn.Linear(n_filters, n_filters),
        )
        self.text_adaptor = nn.Linear(embedding_dim, n_filters)

    def forward(self, sample, ch_idx, mask, ct_exclude=None):
        """
        sample: (B, C, 3, H, W) images
        ch_idx: integer marker indices for the channels
        mask: (B, C) padding mask, True=ignore
        ct_exclude: optional per-sample iterables of cell-type indices to suppress

        Returns ``(None, None, None, marker_pos_attn, probs, image_embedding)``.
        """
        # Encode image
        cls_token_embedding, marker_pos_attn = self.image_encoder(
            sample, ch_idx, mask, return_attention=True
        )
        image_embedding = self.image_adaptor(cls_token_embedding)
        image_embedding = image_embedding / image_embedding.norm(dim=-1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()

        # extract probabilities for each image
        raw_text_embedding_all_classes = self.ct_embedding.weight  # shape = [n_celltypes, embedding_dim]
        text_embedding_all_classes = self.text_adaptor(raw_text_embedding_all_classes)
        text_embedding_all_classes = text_embedding_all_classes / text_embedding_all_classes.norm(dim=-1, keepdim=True)
        logits_per_image_all_classes = logit_scale * image_embedding @ text_embedding_all_classes.t()

        if ct_exclude is not None:
            for row, excluded in enumerate(ct_exclude):
                for ct in excluded:
                    logits_per_image_all_classes[row][ct] = -1e4

        probs = torch.softmax(logits_per_image_all_classes, dim=-1)  # shape = [global_batch_size, n_celltypes]

        # normalize marker_pos_attn by max value
        marker_pos_attn = marker_pos_attn / torch.max(marker_pos_attn, dim=-1, keepdim=True)[0]

        return None, None, None, marker_pos_attn, probs, image_embedding
+ + Args: + conf: Configuration object with model hyperparameters for dataset, encoder, decoder, + channel mixer, and patch mixer components. + """ + + def __init__(self, conf): + super().__init__() + self.conf = conf + # ---------------------------------------------------------------------------- # + # --------------------------- I Encoder components --------------------------- # + # ---------------------------------------------------------------------------- # + + # ----------------------------- 1. Channel Former ---------------------------- # + + # ------------------------ patchify and embed patches ------------------------ # + self.patch_embed = PatchEmbedChannelFree( + img_size=conf.ds.patch_size, token_size=conf.ds.token_size, embed_dim=conf.ds.token_size**2 + ) + # ---------------------------------------------------------------------------- # + self.num_patches = self.patch_embed.num_patches + self.channel_proj = nn.Sequential( + nn.Linear(conf.ds.token_size**2, conf.cm.dim * 2), + nn.LayerNorm(conf.cm.dim * 2), + nn.GELU(), + nn.Linear(conf.cm.dim * 2, conf.cm.dim), + nn.LayerNorm(conf.cm.dim), + ) + self.channel_enc_blocks = nn.ModuleList( + [ + Block( + dim=conf.cm.dim, + num_heads=conf.cm.n_heads, + mlp_ratio=conf.cm.mlp_ratio, + qkv_bias=True, + norm_layer=nn.LayerNorm, + ) + for _ in range(conf.cm.n_layers) + ] + ) + self.channel_norm = nn.LayerNorm(conf.cm.dim) + + # -------------------------- Masker and strategies -------------------------- # + self.mask_strategy = conf.ds.mask_strategy + self.mask_ratio = conf.ds.mask_ratio + if self.mask_strategy == "specified": + # For specified strategy, we need to pass channels parameter during call + self.masker = random_masking(self.mask_ratio, self.mask_strategy) + self.mask_channels = list(getattr(conf.ds, "mask_channels")) + else: + self.masker = random_masking(self.mask_ratio, self.mask_strategy) + # ---------------------------------------------------------------------------- # + + # 
----------------------------- Marker Embeddings ---------------------------- # + self.marker_cls_token = nn.Parameter(torch.zeros(1, 1, conf.cm.dim)) + self.marker_dim = conf.ds.marker_dim + marker_dict = pickle.load(open(os.path.join(os.getcwd(),"src","utils","eva_kit","GenePT_embedding.pkl"), "rb")) + self.marker_embed = MarkerEmbeddingGenePT(marker_dict, self.marker_dim) + self.marker_proj = nn.Sequential( + nn.Linear(self.marker_dim, conf.cm.dim), + nn.LayerNorm(conf.cm.dim), + ) + # ---------------------------------------------------------------------------- # + + # ------------------------------ 2. Patch Former ----------------------------- # + self.linker_proj = nn.Sequential( + nn.Linear(conf.cm.dim, conf.cm.dim * 2), + nn.LayerNorm(conf.cm.dim * 2), + nn.GELU(), + nn.Linear(conf.cm.dim * 2, conf.pm.dim), + nn.LayerNorm(conf.pm.dim), + ) # project Dim from channel to patch + + # Positional embeddings and tokens + self.cls_token = nn.Parameter( + torch.zeros(1, 1, conf.pm.dim) + ) # share with decoder as patch-level representation + self.enc_pos_embed = nn.Parameter( + torch.zeros(1, self.num_patches + 1, conf.pm.dim), requires_grad=False + ) # Fixed sin-cos embedding + + # Encoder transformer blocks + self.patch_enc_blocks = nn.ModuleList( + [ + Block( + dim=conf.pm.dim, + num_heads=conf.pm.n_heads, + mlp_ratio=conf.pm.mlp_ratio, + qkv_bias=True, + norm_layer=nn.LayerNorm, + ) + for _ in range(conf.pm.n_layers) + ] + ) + self.enc_norm = nn.LayerNorm(conf.pm.dim) + self.enc_proj = nn.Sequential( + nn.Linear(conf.pm.dim, conf.pm.out_dim * 2), + nn.LayerNorm(conf.pm.out_dim * 2), + nn.GELU(), + nn.Linear(conf.pm.out_dim * 2, conf.de.dim), + nn.LayerNorm(conf.de.dim), + ) + + # ---------------------------------------------------------------------------- # + + # ---------------------------------------------------------------------------- # + # --------------------------- II Decoder components -------------------------- # + # 
---------------------------------------------------------------------------- # + + # self.decoder_embed = nn.Linear(conf.pm.dim, conf.de.dim, bias=True) + self.mask_token = nn.Parameter(torch.zeros(1, 1, 1, conf.de.dim)) # match with BCND + + # ---------------- Positional encoding for patch-level flatten -------------- # + self.decoder_pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, conf.de.dim), requires_grad=False) + self.flatten_dim_mapper = "(B C) N D" + # ---------------------------------------------------------------------------- # + + # Decoder transformer blocks + self.decoder_blocks = nn.ModuleList( + [ + Block( + dim=conf.de.dim, + num_heads=conf.de.n_heads, + mlp_ratio=conf.de.mlp_ratio, + qkv_bias=True, + norm_layer=nn.LayerNorm, + ) + for _ in range(conf.de.n_layers) + ] + ) + + self.decoder_norm = nn.LayerNorm(conf.de.dim) + self.decoder_pred = nn.Linear(conf.de.dim, conf.ds.token_size**2, bias=True) + # ---------------------------------------------------------------------------- # + self.initialize_weights() + + def initialize_weights(self) -> None: + """Initialize model weights using standard transformer initialization.""" + # Initialize position embeddings + pos_embed = get_2d_sincos_pos_embed( + self.enc_pos_embed.shape[-1], int(self.patch_embed.num_patches**0.5), cls_token=True + ) + self.enc_pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + + # Initialize decoder position embeddings (patch-only) + pos_embed = get_2d_sincos_pos_embed( + self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**0.5), cls_token=True + ) + self.decoder_pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + + # Initialize patch embedding weights + w = self.patch_embed.proj.weight.data + torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + # Initialize tokens + torch.nn.init.normal_(self.cls_token, std=0.02) + torch.nn.init.normal_(self.mask_token, std=0.02) # all along the whole model + 
torch.nn.init.normal_(self.marker_cls_token, std=0.02) + + # Initialize other layers + self.apply(self._init_weights) + + def _init_weights(self, m: nn.Module) -> None: + """Initialize weights for a specific module. + + Args: + m: Module to initialize (Linear or LayerNorm). + """ + if isinstance(m, nn.Linear): + torch.nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def _embed_marker(self, markers: list[list], expand_num=None): + """Embed markers and expand to patch dimensions. + + Args: + markers: List of marker lists for each channel in batch + expand_num: Number of patches to expand to. Defaults to self.num_patches. + + Returns: + Expanded marker embeddings [B, C, N, marker_dim] + """ + m_embed = [] + for m in markers: + m_embed.append(self.marker_embed(m)) + m_embed = torch.stack(m_embed, dim=0) # [B, C, marker_dim] + if not expand_num: + expand_num = self.num_patches + + m_embed = m_embed.unsqueeze(2).expand(-1, -1, expand_num, -1) + return m_embed + + def channel_forward(self, image, marker, channel_mask=None, infer_mask=None): + """Process input through channel mixer with masking. + + Args: + image: Input images [B, C, H, W] + marker: Marker embeddings for each channel + infer_mask: Pre-defined mask for inference. Defaults to None. 
+ + Returns: + Tuple of (processed features [B, N, C+1, D], raw mask [C, N] or [B, C, N]) + """ + # define a channel former forward process + # image: [B, C, H, W] + B, C, H, W = image.shape + x = self.patch_embed(image) # [B, N, P*P*C], with channel flatten, [B, C, N, P*P] with channel-agnostic + + x = self.channel_proj(x) # [B, C, N, D] + + # Add marker embeddings + marker_embeddings = self._embed_marker(marker).to(x.device) + marker_embeddings = self.marker_proj(marker_embeddings) + x = x + marker_embeddings + + # --------------------------- Generate mask matrix --------------------------- # + if infer_mask is not None: + raw_mask = infer_mask + else: + if self.mask_strategy == "specified": + raw_mask = self.masker(x, self.mask_channels) # Pass channels parameter + else: + raw_mask = self.masker(x) # [C, N] or [B, C, N] + + # channel_mask if provided (for padding) + if channel_mask is not None: + # channel_mask: [B, C] -> [B, C, N] to match raw_mask + channel_mask = channel_mask[..., None].expand(-1, -1, self.num_patches) # [B, C] to [B, C, N] + # Combine channel_mask with raw_mask: if either is 1, result is 1 + if raw_mask.dim() == 2: # [C, N] + # Expand raw_mask to [B, C, N] to match channel_mask + raw_mask = raw_mask[None, ...].expand(B, -1, -1) # [B, C, N] + # Now both are [B, C, N], combine them + raw_mask = torch.logical_or(raw_mask.bool(), channel_mask.bool()).float() + + if raw_mask.dim() == 2: # [C, N] + cls_mask = torch.zeros([C + 1, self.num_patches], device=x.device) + cls_mask[1:, :] = raw_mask # [C+1, N] + cls_mask = cls_mask[None, ...].permute(0, 2, 1).expand(B, -1, -1) # [B, N, C+1] + cls_mask = cls_mask.reshape(-1, C + 1) + attn_mask = (cls_mask.unsqueeze(1) + cls_mask.unsqueeze(2)).clamp(max=1) # [B*N, C+1, C+1] + attn_mask = attn_mask.unsqueeze(1) # [B*N, 1, C+1, C+1], broadcast to each head + else: + raise ValueError(f"raw_mask shape not supported: {raw_mask.shape}") + # 
---------------------------------------------------------------------------- # + + # replace invisible tokens as a unified value, and mask them out in attention layers. + mask_token = self.mask_token.expand(x.shape[0], x.shape[1], x.shape[2], x.shape[3]) + if raw_mask.dim() == 2: + x = torch.where(raw_mask[None, ..., None] == 1, mask_token.to(x.device), x) + else: + raise ValueError(f"raw_mask shape not supported: {raw_mask.shape}") + + x = rearrange(x, "B C N D -> (B N) C D") + cls_tokens = self.marker_cls_token.expand(x.shape[0], -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + # Apply channel transformer blocks + for block in self.channel_enc_blocks: + x = block(x, attn_mask=attn_mask) # Process channel-wise relationships + x = self.channel_norm(x) + + # The return x value will contain cls token, use or not in patch-former + x = rearrange(x, "(B N) C D -> B N C D", B=B, C=C + 1) + return x, raw_mask + + def patch_forward(self, input_x): + """Process features through patch mixer. + + Args: + input_x: Features from channel mixer [B, N, C + 1, D] + + Returns: + Processed features [B, C, N+1, D] (before enc_proj) + """ + + x = input_x[:, :, 0:1, :] + x = self.linker_proj(x) + B, N, C, D = x.shape + x = rearrange(x, "B N C D -> (B C) N D", B=B, N=N, C=C, D=D) + cls_token = self.cls_token.repeat(x.shape[0], 1, 1) + x = torch.cat([x, cls_token], dim=1) + x = x + self.enc_pos_embed + + for blk in self.patch_enc_blocks: + x = blk(x) + x = self.enc_norm(x) + x = rearrange(x, "(B C) N D -> B C N D", B=B, C=C, N=N + 1) + return x + + def forward_encoder(self, image, marker, channel_mask=None, infer_mask=None): + """Complete forward pass through encoder. + + Args: + image: Input images [B, C, H, W] + marker: Marker embeddings for each channel + infer_mask: Pre-defined mask for inference. Defaults to None. 
+ + Returns: + Tuple of (encoded features [B, C, N+1, D], raw mask [C, N] or [B, C, N]) + """ + x, raw_mask = self.channel_forward(image, marker, channel_mask, infer_mask) + x = self.patch_forward(x) + return x, raw_mask + + def forward_decoder(self, x: torch.Tensor, marker, channel_mask=None) -> torch.Tensor: + """Reconstruct input from encoded representation. + + Args: + x: Encoded representation [B, C, N+1, D] (before enc_proj) + marker: Marker embeddings for each channel + + Returns: + Reconstructed patches [B, C, N+1, token_size**2] + """ + x = self.enc_proj(x) + B, C, N, D = x.shape + x = x.repeat(1, len(marker[0]), 1, 1) + C = len(marker[0]) + marker_embeddings = self._embed_marker(marker, expand_num=N).to(x.device) + marker_embeddings = self.marker_proj(marker_embeddings) + x = x + marker_embeddings + + # --------------------------- Generate attention mask for decoder --------------------------- # + if channel_mask is not None: + # channel_mask: [B, C] -> [B, C, N] to match the decoder input shape + channel_mask = channel_mask[..., None].expand(-1, -1, N) # [B, C] to [B, C, N] + # For patch flatten: [B, C, N] -> [B*C, N] -> [B*C*N, N] + channel_mask = channel_mask.reshape(B * C, N) # [B*C, N] + channel_mask = channel_mask.reshape(-1, N) # [B*C*N, N] + attn_mask = (channel_mask.unsqueeze(1) + channel_mask.unsqueeze(2)).clamp(max=1) # [B*C*N, N, N] + attn_mask = attn_mask.unsqueeze(1) # [B*C*N, 1, N, N] + else: + attn_mask = None + # ---------------------------------------------------------------------------- # + + x = rearrange(x, f"B C N D -> {self.flatten_dim_mapper}") + # project encoder output to match decoder + x = x + self.decoder_pos_embed + + for blk in self.decoder_blocks: + x = blk(x, attn_mask=attn_mask) + x = self.decoder_norm(x) + x = self.decoder_pred(x) + x = rearrange(x, f"{self.flatten_dim_mapper} -> B C N D", N=N, C=C, B=B) + return x + + def forward( + self, imgs: torch.Tensor, marker_in, channel_mask=None, marker_out=None, infer_mask=None 
+ ) -> Tuple[torch.Tensor, torch.Tensor]: + """Complete forward pass through MAE model. + + Args: + imgs: Input images [B, C, H, W] + marker_in: Input marker embeddings for encoding + marker_out: Output marker embeddings for decoding. Defaults to marker_in. + infer_mask: Pre-defined mask for inference. Defaults to None. + + Returns: + Tuple of (reconstructed patches [B, C, N+1, token_size**2], mask [C, N] or [B, C, N]) + """ + encoder_x, raw_mask = self.forward_encoder( + image=imgs, marker=marker_in, channel_mask=channel_mask, infer_mask=infer_mask + ) + if not marker_out: + marker_out = marker_in + pred = self.forward_decoder(encoder_x, marker_out, channel_mask) + return pred, raw_mask diff --git a/src/models/KRONOSModel.py b/src/models/KRONOSModel.py new file mode 100644 index 0000000..62b7f70 --- /dev/null +++ b/src/models/KRONOSModel.py @@ -0,0 +1,506 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
# References:
#   https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py

from functools import partial
import numpy as np
import math
import logging
from typing import Sequence, Tuple, Union, Callable

import torch
import torch.nn as nn
import torch.utils.checkpoint
from torch.nn.init import trunc_normal_

from src.utils.kronos_kit.dino_head import DINOHead
from src.utils.kronos_kit.mlp import Mlp
from src.utils.kronos_kit.patch_embed import PatchEmbed
from src.utils.kronos_kit.swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
from src.utils.kronos_kit.block import NestedTensorBlock as Block
from src.utils.kronos_kit.attention import MemEffAttention

logger = logging.getLogger("dinov2")

def apply_masks(x, masks):
    """Gather entries of ``x`` along dim 1 for every index tensor in ``masks``.

    One gathered tensor is produced per element of ``masks``; the results are
    concatenated along dim 0.

    :param x: tensor of shape [B (batch-size), N (num-patches), D (feature-dim)]
    :param masks: list of tensors containing indices of patches in [N] to keep
    :return: concatenation (dim 0) of the per-mask gathered tensors
    """
    all_x = []
    for m in masks:
        # Tile each index across the feature dimension so that gather picks
        # the same position for every feature channel.
        mask_keep = m.unsqueeze(-1).repeat(1, 1, x.size(-1))
        all_x += [torch.gather(x, dim=1, index=mask_keep)]
    return torch.cat(all_x, dim=0)


# --------------------------------------------------------
# 1D sine-cosine position embedding
# References:
# MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
# --------------------------------------------------------
def get_1d_sincos_marker_embed(embed_dim, max_marker_id, cls_token=False):
    """Build a fixed sine-cosine embedding table for marker ids ``0 .. max_marker_id - 1``.

    Args:
        embed_dim: width of each embedding row (must be even, see
            ``get_1d_sincos_marker_embed_from_grid``).
        max_marker_id: number of marker ids to encode (table length).
        cls_token: when true, an all-zero row is prepended to the table.

    Returns:
        numpy array of shape [max_marker_id, embed_dim], or
        [1 + max_marker_id, embed_dim] when ``cls_token`` is set.
    """
    ids = np.arange(max_marker_id, dtype=float)
    marker_embed = get_1d_sincos_marker_embed_from_grid(embed_dim, ids)
    if cls_token:
        marker_embed = np.concatenate([np.zeros([1, embed_dim]), marker_embed], axis=0)
    return marker_embed

def get_1d_sincos_marker_embed_from_grid(embed_dim, ids):
    """Encode ``ids`` with sine and cosine features.

    Args:
        embed_dim: output dimension for each marker; asserted to be even.
        ids: numpy array of marker ids to be encoded, flattened to size (M,).

    Returns:
        numpy array of shape (M, embed_dim): the first half of each row holds
        the sine terms, the second half the cosine terms.
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=float)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)
    ids = ids.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', ids, omega)  # (M, D/2), outer product
    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)
    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb

def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
    """Recursively call ``fn(module=..., name=...)`` on the children of ``module``.

    Args:
        fn: callable invoked with keyword arguments ``module`` and ``name``.
        module: root of the traversal.
        name: dotted name of ``module``; child names are appended with ".".
        depth_first: when true a module is visited after its children,
            otherwise before them.
        include_root: whether ``fn`` is also called on ``module`` itself
            (children are always visited with ``include_root=True``).

    Returns:
        The ``module`` that was passed in.
    """
    if not depth_first and include_root:
        fn(module=module, name=name)
    for child_name, child_module in module.named_children():
        child_name = ".".join((name, child_name)) if name else child_name
        named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
    if depth_first and include_root:
        fn(module=module, name=name)
    return module


class BlockChunk(nn.ModuleList):
    """A group of modules that is called as a single stage of the transformer."""

    def forward(self, x, return_attention=False):
        """Run ``x`` through the chunk.

        Non-``Block`` entries (the ``nn.Identity`` placeholders added by
        ``DinoVisionTransformer.__init__``) are called with ``x`` only; a
        ``Block`` entry additionally receives ``return_attention``.
        """
        for b in self:
            if isinstance(b, Block):
                x = b(x, return_attention)
                # NOTE(review): the loop exits at the first Block instance, so
                # later entries of this chunk are never called - confirm this
                # is intended when return_attention is False.
                break
            else:
                x = b(x)

        return x


class DinoVisionTransformer(nn.Module):
    """Vision transformer with per-marker (per-channel) embeddings."""

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        stride_size=16,
        num_markers=512,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        ffn_bias=True,
        proj_bias=True,
        drop_path_rate=0.0,
        drop_path_uniform=False,
        init_values=None,  # for layerscale: None or 0 => no layerscale
        embed_layer=PatchEmbed,
        act_layer=nn.GELU,
        block_fn=Block,
        ffn_layer="mlp",
        block_chunks=1,
        num_register_tokens=0,
        interpolate_antialias=False,
        interpolate_offset=0.1,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            stride_size (int): stride forwarded to the patch embedding layer
            num_markers (int): number of rows of the fixed sine-cosine marker embedding table
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            proj_bias (bool): enable bias for proj in attn if True
            ffn_bias (bool): enable bias for ffn if True
            drop_path_rate (float): stochastic depth rate
            drop_path_uniform (bool): apply uniform drop rate across blocks
            init_values (float): layer-scale init values
            embed_layer (nn.Module): patch embedding layer
            act_layer (nn.Module): MLP activation layer
            block_fn (nn.Module): transformer block class
            ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
            block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
            num_register_tokens: (int) number of extra cls tokens (so-called "registers")
            interpolate_antialias: (bool) flag to apply anti-aliasing when interpolating positional embeddings
            interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings

        Raises:
            NotImplementedError: if ``ffn_layer`` is not one of the names above.
        """
        super().__init__()
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 1
        self.n_blocks = depth
        self.num_heads = num_heads
        self.patch_size = patch_size
        self.stride_size = stride_size
        self.num_register_tokens = num_register_tokens
        self.interpolate_antialias = interpolate_antialias
        self.interpolate_offset = interpolate_offset

        # The patch embedding is always built for a single input channel.
        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, stride_size=stride_size, in_chans=1, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        # -- modality embedding
        # Kept as a plain numpy array (not a parameter or buffer); it is
        # converted to a tensor in prepare_tokens_with_masks.
        self.marker_embed = get_1d_sincos_marker_embed(embed_dim, num_markers, cls_token=False)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        assert num_register_tokens >= 0
        self.register_tokens = (
            nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
        )

        if drop_path_uniform is True:
            dpr = [drop_path_rate] * depth
        else:
            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

        # Resolve the ffn_layer name into a layer class / factory.
        if ffn_layer == "mlp":
            logger.info("using MLP layer as FFN")
            ffn_layer = Mlp
        elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
            logger.info("using SwiGLU layer as FFN")
            ffn_layer = SwiGLUFFNFused
        elif ffn_layer == "identity":
            logger.info("using Identity layer as FFN")

            def f(*args, **kwargs):
                return nn.Identity()

            ffn_layer = f
        else:
            raise NotImplementedError

        blocks_list = [
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                ffn_layer=ffn_layer,
                init_values=init_values,
            )
            for i in range(depth)
        ]
        if block_chunks > 0:
            self.chunked_blocks = True
            chunked_blocks = []
            chunksize = depth // block_chunks
            for i in range(0, depth, chunksize):
                # this is to keep the block index consistent if we chunk the block list
                chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
            self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
        else:
            self.chunked_blocks = False
            self.blocks = nn.ModuleList(blocks_list)

        self.norm = norm_layer(embed_dim)
        self.head = nn.Identity()

        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))

        self.init_weights()

    def init_weights(self):
        """Initialise the positional embedding, class/register tokens and all Linear layers.

        ``init_weights_vit_timm`` is defined further down in this module.
        """
        trunc_normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.cls_token, std=1e-6)
        if self.register_tokens is not None:
            nn.init.normal_(self.register_tokens, std=1e-6)
        named_apply(init_weights_vit_timm, self)

    def interpolate_pos_encoding(self, x, w, h, npatch):
        """Return the positional embedding resized to ``npatch`` patch positions.

        Args:
            x: token tensor; only its dtype and last dimension are read.
            w, h: the two spatial sizes of the input image; only compared
                with each other for the fast path.
            npatch: number of patch positions per marker.

        Returns:
            ``self.pos_embed`` unchanged when ``npatch`` equals the stored
            number of patch positions and ``w == h``; otherwise the class
            position followed by the bicubically interpolated patch positions,
            cast to the dtype of ``x``.
        """
        # h0 = h // self.patch_size
        # w0 = w // self.patch_size

        # The target grid is taken as square: both sides are int(sqrt(npatch)).
        h0_stride = int(np.sqrt(npatch))
        w0_stride = int(np.sqrt(npatch))

        previous_dtype = x.dtype
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w==h:
            return self.pos_embed
        pos_embed = self.pos_embed.float()
        class_pos_embed = pos_embed[:, 0]
        patch_pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        # w0 = w // self.patch_size
        # h0 = h // self.patch_size
        M = int(math.sqrt(N))  # Recover the number of patches in each dimension
        assert N == M * M
        kwargs = {}
        if self.interpolate_offset:
            # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
            # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
            sx = float(w0_stride + self.interpolate_offset) / M
            sy = float(h0_stride + self.interpolate_offset) / M
            kwargs["scale_factor"] = (sx, sy)
        else:
            # Simply specify an output size instead of a scale factor
            kwargs["size"] = (w0_stride, h0_stride)

        # patch_pos_embed = patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2)
        # patch_pos_embed = nn.functional.interpolate(patch_pos_embed, mode="bicubic", antialias=self.interpolate_antialias, **kwargs)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
            mode="bicubic",
            antialias=self.interpolate_antialias,
            **kwargs,
        )
        assert (w0_stride, h0_stride) == patch_pos_embed.shape[-2:]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        # patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1)
        # patch_pos_embed = patch_pos_embed.reshape(1, -1, dim)

        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)

    def prepare_tokens_with_masks(self, x, masks=None, marker_ids=None):
        """Turn an image batch into the token sequence fed to the transformer blocks.

        Steps: patch-embed the input, add the marker embedding selected by
        ``marker_ids``, optionally replace masked tokens with ``mask_token``,
        prepend the class token, add the positional embedding (repeated once
        per marker) and insert the register tokens after the class token.

        Args:
            x: input of shape [B, num_marker, w, h].
            masks: optional boolean tensor; where true the token is replaced
                by ``self.mask_token``.
            marker_ids: index tensors passed to ``apply_masks`` to select
                rows of the marker embedding table; required.

        Returns:
            Token tensor with the class token first.
        """
        B, num_marker, w, h = x.shape
        x = self.patch_embed(x)
        # presumably patch_embed returns [B, num_marker * patches, D] - TODO confirm against PatchEmbed
        num_patches = int(x.shape[1] / num_marker)


        # selecting marker embeddings based on marker_index
        assert marker_ids is not None, "marker_ids should be provided"
        marker_embed = torch.from_numpy(self.marker_embed).float().unsqueeze(0).to(device=x.device, dtype=x.dtype)
        marker_embed = marker_embed.repeat(B, 1, 1)
        marker_embed = apply_masks(marker_embed, marker_ids)
        # Repeat each selected marker row once per patch of that marker.
        marker_embed = torch.repeat_interleave(marker_embed, num_patches, 1)

        # adding selected marker embeddings to patch embeddings
        x = x + marker_embed

        if masks is not None:
            x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)

        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)

        pos_embed = self.interpolate_pos_encoding(x, w, h, num_patches)
        # Same patch positions are reused for every marker; the class position is kept once.
        pos_embed = torch.cat((pos_embed[:, 0, :].unsqueeze(0), pos_embed[:, 1:, :].repeat(1, num_marker, 1)), dim=1)
        x = x + pos_embed

        if self.register_tokens is not None:
            # Register tokens go between the class token and the patch tokens.
            x = torch.cat(
                (
                    x[:, :1],
                    self.register_tokens.expand(x.shape[0], -1, -1),
                    x[:, 1:],
                ),
                dim=1,
            )

        return x

    def forward_features_list(self, x_list, masks_list, marker_ids_list):
        """Run several inputs through the blocks and return one feature dict per input.

        Args:
            x_list: list of image batches.
            masks_list: list of masks, aligned with ``x_list``.
            marker_ids_list: list of marker id selections, aligned with ``x_list``.

        Returns:
            List of dicts with the same keys as ``forward_features``.
        """
        x = [self.prepare_tokens_with_masks(x, masks, marker_ids) for x, masks, marker_ids in zip(x_list, masks_list, marker_ids_list)]
        # assumes every block accepts a list of tensors - TODO confirm against NestedTensorBlock
        for blk in self.blocks:
            x = blk(x)

        all_x = x
        output = []
        for x, masks in zip(all_x, masks_list):
            x_norm = self.norm(x)
            output.append(
                {
                    "x_norm_clstoken": x_norm[:, 0],
                    "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
                    "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
                    "x_prenorm": x,
                    "masks": masks,
                }
            )
        return output

    def forward_features(self, x, masks=None, marker_ids=None):
        """Compute normalised token features for ``x``.

        A list input is delegated to ``forward_features_list``.

        Returns:
            Dict with the normalised class token (``x_norm_clstoken``),
            register tokens (``x_norm_regtokens``), patch tokens
            (``x_norm_patchtokens``), the un-normalised block output
            (``x_prenorm``) and the ``masks`` that were passed in.
        """
        if isinstance(x, list):
            return self.forward_features_list(x, masks, marker_ids)

        x = self.prepare_tokens_with_masks(x, masks, marker_ids)

        for blk in self.blocks:
            x = blk(x)

        x_norm = self.norm(x)
        return {
            "x_norm_clstoken": x_norm[:, 0],
            "x_norm_regtokens": x_norm[:, 1: self.num_register_tokens + 1],
            "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1:],
            "x_prenorm": x,
            "masks": masks,
        }

def forward(self, x, masks=None, marker_ids=None, is_training=False): + if marker_ids is None: + marker_ids = [torch.tensor([i+4 for i in range(x.shape[1])], device=x.device) for _ in range(x.shape[0])] + ret = self.forward_features(x, masks, marker_ids) + if is_training: + return ret + else: + B, num_marker, w, h = x.shape + tokens_per_row = len([i for i in range(0, h-self.patch_size+1, self.stride_size)]) + tokens_per_col = len([i for i in range(0, w-self.patch_size+1, self.stride_size)]) + + patch_features = ret["x_norm_clstoken"] + patch_token_features = ret["x_norm_patchtokens"].reshape(B, num_marker, tokens_per_row, tokens_per_col, self.embed_dim) + patch_marker_features = torch.mean(torch.mean(patch_token_features, dim=-2), dim=-2) + + return patch_features, patch_marker_features, patch_token_features + + def _get_intermediate_layers_not_chunked(self, x, marker_ids, n=1): + x = self.prepare_tokens_with_masks(x, masks=None, marker_ids=marker_ids) + # If n is an int, take the n last blocks. If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def _get_intermediate_layers_chunked(self, x, marker_ids, n=1): + x = self.prepare_tokens_with_masks(x, masks=None, marker_ids=marker_ids) + output, i, total_block_len = [], 0, len(self.blocks[-1]) + # If n is an int, take the n last blocks. 
If it's a list, take them + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for block_chunk in self.blocks: + for blk in block_chunk[i:]: # Passing the nn.Identity() + x = blk(x) + if i in blocks_to_take: + output.append(x) + i += 1 + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, marker_ids, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + norm=True, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + if self.chunked_blocks: + outputs = self._get_intermediate_layers_chunked(x, marker_ids, n) + else: + outputs = self._get_intermediate_layers_not_chunked(x, marker_ids, n) + if norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs] + if reshape: + B, _, w, h = x.shape + outputs = [ + out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() + for out in outputs + ] + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + def get_last_self_attention(self, x, masks=None, marker_ids=None): + if isinstance(x, list): + return self.forward_features_list(x, masks, marker_ids) + + x = self.prepare_tokens_with_masks(x, masks, marker_ids) + + # Run through model, at the last block just return the attention. 
+ for i, blk in enumerate(self.blocks): + if i < len(self.blocks) - 1: + x = blk(x) + else: + return blk(x, return_attention=True) + + +def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def vit_small(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_base(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_large(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs): + """ + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model diff --git a/src/models/__init__.py b/src/models/__init__.py index e69de29..cc26f44 100644 --- a/src/models/__init__.py +++ b/src/models/__init__.py @@ -0,0 +1,52 @@ +def DeepCell( + n_filters=256, + n_heads=4, + n_domains=9, + 
embed_size=256, + **args + ): + + from src.utils.deepcell_kit.config import DCTConfig + import numpy as np + from src.models.DeepCellModel import CellTypeDataEncoder + + dct_config = DCTConfig() + + embedding_model_name = "deepseek-r1-70b-llama-distill-q4_K_M" + + marker2embedding = dct_config.get_channel_embedding( + embedding_model_name=embedding_model_name + ) + + marker_embeddings = np.zeros_like(list(marker2embedding.values()), dtype=np.float32) + for marker, ebd in marker2embedding.items(): + if marker not in dct_config.marker2idx: + print("bad_marker?", marker) + idx = dct_config.marker2idx[marker] + marker_embeddings[idx] = ebd + + model = CellTypeDataEncoder( + n_filters=n_filters, + n_heads=n_heads, + embed_size=256, + marker_embeddings=marker_embeddings, + img_feature_extractor='conv', + ) + return model + +def KRONOS(patch_size=16, num_register_tokens=0,**kwargs): + from src.models.KRONOSModel import DinoVisionTransformer + from src.utils.kronos_kit.block import NestedTensorBlock as Block + from src.utils.kronos_kit.attention import MemEffAttention + from functools import partial + + model = DinoVisionTransformer( + patch_size=patch_size, + depth=12, + num_heads=6, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model diff --git a/src/run/CellContrastEmbed.py b/src/run/CellContrastEmbed.py index 7103ed9..a8b98ca 100644 --- a/src/run/CellContrastEmbed.py +++ b/src/run/CellContrastEmbed.py @@ -1,38 +1,105 @@ -import torch -from src.data.CellContrastData import EmbedDataset -from src.models.CellContrastModel import ContrastiveLearning -from src.utils.utils import load -from src.utils.utils import set_seed - -def embed(**args): - """ - Embed visual representations of cells. 
- - Parameters: - image_dir (str): Path to dir in which torch.tensors of cell cut outs are - model_name (str): Path and name of model torch save dict - args (dict): Arguments - """ - - batch_size = args['batch_size'] - seed = args['seed'] - - # move to GPU (if available) - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - set_seed(seed) - - train_dataset = EmbedDataset(split='train', - save_embed_data=True, - **args) - test_dataset = EmbedDataset(split='test', - save_embed_data=True, - **args) - - model = ContrastiveLearning(channels=train_dataset.img_shape[0], - **args).to(device, torch.float32) - model.load_state_dict(load(args['output_name'], save_keys='model', device=device)) - model.eval() - model.mode = 'embed' - - train_dataset.save_embed_data(model, device=device, batch_size=batch_size) - test_dataset.save_embed_data(model, device=device, batch_size=batch_size) +import torch +from src.data.CellContrastData import EmbedDataset +from src.models.CellContrastModel import ContrastiveLearning +from src.utils.utils import load +from src.utils.utils import set_seed +import os + +def embed(**args): + """ + Embed visual representations of cells. 
+ + Parameters: + image_dir (str): Path to dir in which torch.tensors of cell cut outs are + model_name (str): Path and name of model torch save dict + args (dict): Arguments + """ + batch_size = args['batch_size'] + seed = args['seed'] + foundation_model = args['foundation_model'] + + # move to GPU (if available) + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + set_seed(seed) + + if foundation_model == 'deepcell': + + from src.data.DeepCellData import PatchDataset as DeepCellDataset + train_dataset = DeepCellDataset(split='train', + save_embed_data=True, + **args) + test_dataset = DeepCellDataset(split='test', + save_embed_data=True, + **args) + + from src.models import DeepCell + model = DeepCell( + n_filters=256, # Must be 256 as load_state has 256 weights + n_heads=4, + n_domains=9, + embed_size=256, + **args + ).to(device, torch.float32) + model.load_state_dict(torch.load(args['output_name'], map_location=device)) + + elif foundation_model == 'kronos': + + from src.data.KRONOSData import KRONOSDataset + train_dataset = KRONOSDataset(split='train', + save_embed_data=True, + **args) + test_dataset = KRONOSDataset(split='test', + save_embed_data=True, + **args) + + from src.models import KRONOS + model = KRONOS( # Default arguments from KRONOS, + img_size=224, # dont touch + patch_size=16, + embed_dim=384, # Must be 384 + stride_size=16, + num_markers=512, + init_values=1.0e-05, + ffn_layer='mlp', + block_chunks=4, + num_register_tokens=16, + ).to(device, torch.float32) + model.load_state_dict(torch.load(args['output_name'], map_location=device)) + + elif foundation_model == 'eva': + + from omegaconf import OmegaConf + conf = OmegaConf.load( + os.path.join(os.getcwd(),"src","utils","eva_kit","config.yaml") + ) + + from src.data.EvaData import EvaDataset + train_dataset = EvaDataset(conf=conf, + split='train', + save_embed_data=True, + **args) + test_dataset = EvaDataset(conf=conf, + split='test', + save_embed_data=True, + **args) + from 
class DCTConfig:
    """Static configuration and lookup tables for the DeepCell-Types pipeline.

    On construction the class reads its resources from the ``config`` folder
    located next to this module:

    * ``core_celltypes.yaml``   -> ``core_celltypes`` / ``ct2idx``
    * ``master_channels.yaml``  -> ``master_channels``
    * ``marker_embeddings-<model>.json`` -> ``marker2idx`` (index = position of
      the marker in the JSON object)
    * ``channel_mapping.yaml``  -> ``channel_mapping`` (marker alias -> name
      recognized by the model)

    NOTE(review): the ``marker_embeddings-*.json`` file read in ``__init__`` is
    not among the config files visible alongside this module -- confirm that it
    is shipped, otherwise construction raises ``FileNotFoundError``.
    """

    def __init__(self):
        # Scalar hyper-parameters.
        self.SEED = 0
        self.MAX_NUM_CHANNELS = 75
        self.BATCH_SIZE = 400
        self.MAX_CHUNK_PER_CT_PER_DATASET = 25
        self.PERCENTILE_THRESHOLD = 99.0

        self.HIST_NORM_KERNEL_SIZE = 128
        self.CROP_SIZE = 64

        self.STANDARD_MPP_RESOLUTION = 0.5

        # The domain/dataset index tables are not built by this trimmed-down
        # config (their sources are not loaded here).  They are initialised to
        # empty mappings so the ``domain2idx`` / ``dataset2idx`` properties
        # return a value instead of raising AttributeError.
        self._domain2idx = {}
        self._dataset2idx = {}

        self.data_folder = Path(os.path.dirname(__file__)) / "config"
        self._ct2idx, self._core_celltypes = self._load_ct2idx_and_core_celltypes()

        self._master_channels = self._load_master_channels()

        embedding_model_name = "deepseek-r1-70b-llama-distill-q4_K_M"
        marker2embedding = self.get_channel_embedding(
            embedding_model_name=embedding_model_name
        )
        # Marker indices follow the key order of the embedding file, not the
        # order of ``master_channels``.
        self._marker2idx = {ch: idx for idx, ch in enumerate(marker2embedding)}
        self.NUM_CELLTYPES = len(self.ct2idx)

        # Default channel mapping containing all recognized marker name aliases to the
        # names recognized by the model
        with open(self.data_folder / "channel_mapping.yaml") as fh:
            self.channel_mapping = yaml.safe_load(fh)

    @property
    def ct2idx(self):
        """Mapping ``cell type name -> integer index`` (root ``"Cell"`` excluded)."""
        return self._ct2idx

    @property
    def domain2idx(self):
        """Mapping ``domain -> index``; empty because no domain table is loaded."""
        return self._domain2idx

    @property
    def marker2idx(self):
        """Mapping ``marker name -> integer index`` derived from the embedding file."""
        return self._marker2idx

    @property
    def dataset2idx(self):
        """Mapping ``dataset -> index``; empty because no dataset table is loaded."""
        return self._dataset2idx

    @property
    def core_celltypes(self):
        """Nested cell-type hierarchy exactly as parsed from ``core_celltypes.yaml``."""
        return self._core_celltypes

    def _load_ct2idx_and_core_celltypes(self):
        """Load the cell-type hierarchy and derive the flat name -> index table.

        Returns:
            tuple: ``(ct2idx, core_celltypes)`` where ``ct2idx`` enumerates the
            flattened hierarchy with the ``"Cell"`` entry removed.
        """
        with open(self.data_folder / "core_celltypes.yaml", "r") as f:
            core_celltypes = yaml.safe_load(f)

        # ``"Cell"`` is the root of the hierarchy, not a predictable class.
        celltype_names = [
            celltype
            for celltype in flatten_nested_dict(core_celltypes)
            if celltype != "Cell"
        ]
        ct2idx = {ct: idx for idx, ct in enumerate(celltype_names)}

        return ct2idx, core_celltypes

    @property
    def master_channels(self):
        """Content of ``master_channels.yaml``."""
        return self._master_channels

    def _load_master_channels(self):
        """Read ``master_channels.yaml`` from the config folder."""
        with open(self.data_folder / "master_channels.yaml", "r") as f:
            # safe_load, like every other YAML read in this class: the file is
            # plain data and needs no Python-object construction.
            return yaml.safe_load(f)

    def get_channel_embedding(self, embedding_model_name="text-embedding-3-large-1024"):
        """Get the channel embedding from the json file.

        Args:
            embedding_model_name: suffix of the
                ``marker_embeddings-<name>.json`` file to read.

        Returns:
            The parsed JSON content.
        """
        with open(self.data_folder / f"marker_embeddings-{embedding_model_name}.json", "r") as f:
            return json.load(f)

    def get_celltype_embedding(self, embedding_model_name="text-embedding-3-large-1024"):
        """Get the celltype embedding from the json file.

        Args:
            embedding_model_name: suffix of the
                ``celltype_embeddings-<name>.json`` file to read.

        Returns:
            The parsed JSON content.
        """
        with open(self.data_folder / f"celltype_embeddings-{embedding_model_name}.json", "r") as f:
            return json.load(f)

    def get_tct_mapping(self):
        """Get the tissue to celltype mapping from the yaml file."""
        with open(self.data_folder / "tissue_celltype_mapping_merged.yaml", "r") as f:
            return yaml.safe_load(f)
0000000..3fa4143 --- /dev/null +++ b/src/utils/deepcell_kit/config/core_celltypes.yaml @@ -0,0 +1,49 @@ +Cell: + Tcell: + Treg: {} + CD4T: {} + CD8T: {} + NKT: {} + Bcell: + Plasma: {} + NK: {} + Dendritic: {} + Mast: {} + Neutrophil: {} + Macrophage: + Microglial: {} + Langerhans: {} + Monocyte: {} + Epithelial: + CollectingDuct: {} + Melanocyte: {} + Goblet: {} + Paneth: {} + Enterocyte: {} + Endocrine: + AlphaCell: {} + BetaCell: {} + Endothelial: + HSEC: {} + LymphaticEndothelial: {} + BloodVesselEndothelial: {} + Stromal: + Fibroblast: + Stellate: {} + Myofibroblast: {} + SmoothMuscle: {} + Pericyte: + Mesangial: {} + CardiacMuscle: {} + Nerve: + Neuron: {} + Glial: {} + Tumor: {} + Thrombocyte: {} + Erythrocyte: {} + Hepatocyte: {} + Astrocyte: {} + EVT: {} + ICC: {} + LittoralCell: {} + Podocyte: {} diff --git a/src/utils/deepcell_kit/config/master_channels.yaml b/src/utils/deepcell_kit/config/master_channels.yaml new file mode 100644 index 0000000..5e587a0 --- /dev/null +++ b/src/utils/deepcell_kit/config/master_channels.yaml @@ -0,0 +1,177 @@ +- ACE2 +- ANXA1 +- AR +- ARG1 +- ASCT2 +- Amylase +- Bcl-2 +- Beta-catenin +- Biglycan +- CCR7 +- CD10 +- CD103 +- CD106 +- CD107a +- CD11b +- CD11c +- CD123 +- CD127 +- CD138 +- CD14 +- CD147 +- CD15 +- CD154 +- CD16 +- CD161 +- CD163 +- CD19 +- CD1c +- CD20 +- CD206 +- CD207 +- CD209 +- CD21 +- CD25 +- CD27 +- CD271 +- CD294 +- CD3 +- CD31 +- CD33 +- CD34 +- CD35 +- CD36 +- CD38 +- CD39 +- CD4 +- CD44 +- CD45 +- CD45RA +- CD45RO +- CD49a +- CD49f +- CD5 +- CD54 +- CD56 +- CD57 +- CD61 +- CD63 +- CD66 +- CD68 +- CD69 +- CD7 +- CD73 +- CD79a +- CD8 +- CD80 +- CD8a +- CD90 +- CD94 +- CD98 +- CDX2 +- CK10 +- CK14 +- CK17 +- CK19 +- CK5 +- CK6 +- CK7 +- COL1 +- COX2 +- CPT1A +- CXCL5 +- CXCR5 +- Calprotectin +- CgA +- Col4 +- DEFA5 +- E-cadherin +- EGFR +- ER +- EpCAM +- FAP +- FASN +- Fibronectin +- FoxP3 +- GATA3 +- GATA6 +- GFAP +- GLUT1 +- GRB +- GYPA +- GZMB +- Galectin9 +- HER2 +- HIF1A +- HK1 +- HLA-Class-1 +- 
HLA-Class-2 +- HLA-G +- HO1 +- ICOS +- IDO +- IFNg +- IL6 +- IgA +- IgD +- IgM +- Intelectin-1 +- Ki67 +- LAG3 +- Lck +- Lumican +- Lysozyme +- Lyve1 +- MCT1 +- MECA-79 +- MMP7 +- MMP9 +- MPO +- MUC1 +- MUC2 +- MUC5AC +- MUC6 +- MelanA +- NKG2D +- NRF2 +- Na-K-ATPase +- Nestin +- OLFM4 +- OLIG2 +- OX40 +- P-cadherin +- P2Y12 +- P63 +- PD1 +- PDL1 +- PDPN +- PGA3 +- PKM2 +- PMEL +- PROX1 +- PanCK +- S100A4 +- SMA +- SOMATOSTATIN +- SOX10 +- SOX2 +- SOX9 +- SPARC +- Synaptophysin +- TCF1 +- TIGIT +- TIM3 +- TMPRSS2 +- TTF1 +- Tox-Tox2 +- Tryptase +- Vimentin +- XBP1 +- YAP +- c-Caspase-3 +- c-kit +- gH2AX +- gdTCR +- iNOS +- p16 +- p53 diff --git a/src/utils/deepcell_kit/config/tissue_celltype_mapping_merged.yaml b/src/utils/deepcell_kit/config/tissue_celltype_mapping_merged.yaml new file mode 100644 index 0000000..1b31b52 --- /dev/null +++ b/src/utils/deepcell_kit/config/tissue_celltype_mapping_merged.yaml @@ -0,0 +1,420 @@ +Barretts_Esophagus: +- Bcell +- BloodVesselEndothelial +- CD4T +- CD8T +- Dendritic +- Endocrine +- Enterocyte +- Epithelial +- Fibroblast +- Goblet +- ICC +- Macrophage +- Myofibroblast +- NK +- Nerve +- Neuron +- Neutrophil +- Paneth +- Plasma +- SmoothMuscle +- Tcell +- Treg +- Tumor +Breast: +- Bcell +- BloodVesselEndothelial +- CD4T +- CD8T +- Dendritic +- Endothelial +- Epithelial +- Fibroblast +- Macrophage +- Mast +- Monocyte +- Myofibroblast +- NK +- NKT +- Nerve +- Neuron +- Neutrophil +- Pericyte +- Plasma +- SmoothMuscle +- Tcell +- Treg +- Tumor +GI: +- Bcell +- BloodVesselEndothelial +- CD4T +- CD8T +- Dendritic +- Endocrine +- Endothelial +- Enterocyte +- Epithelial +- Fibroblast +- Goblet +- ICC +- LymphaticEndothelial +- Macrophage +- Mast +- Monocyte +- NK +- Nerve +- Neuron +- Neutrophil +- Paneth +- Plasma +- SmoothMuscle +- Tcell +- Treg +- Tumor +Heart: +- Bcell +- BloodVesselEndothelial +- CD4T +- CD8T +- CardiacMuscle +- Dendritic +- Endothelial +- Fibroblast +- Glial +- Macrophage +- Monocyte +- Myofibroblast +- 
Neutrophil +- Neuron +- NK +- NKT +- Pericyte +- Plasma +- SmoothMuscle +- Stromal +Jejunum: +- Bcell +- BloodVesselEndothelial +- CD4T +- CD8T +- Dendritic +- Endocrine +- Endothelial +- Enterocyte +- Epithelial +- Fibroblast +- Goblet +- LymphaticEndothelial +- Macrophage +- Mast +- Monocyte +- Neuron +- Neutrophil +- NK +- Paneth +- Plasma +- SmoothMuscle +- Tcell +- Tumor +Kidney: +- BloodVesselEndothelial +- CD4T +- CD8T +- CollectingDuct +- Dendritic +- Endothelial +- Epithelial +- Fibroblast +- LymphaticEndothelial +- Macrophage +- Mast +- Mesangial +- Monocyte +- Myofibroblast +- NK +- NKT +- Neutrophil +- Podocyte +- SmoothMuscle +- Tcell +- Treg +- Tumor +Liver: +- BloodVesselEndothelial +- CD4T +- CD8T +- Endothelial +- Epithelial +- Fibroblast +- Hepatocyte +- HSEC +- Macrophage +- Myofibroblast +- NKT +- Plasma +- SmoothMuscle +- Stellate +- Tcell +- Treg +- Tumor +Lung: +- Bcell +- BloodVesselEndothelial +- CD4T +- CD8T +- Dendritic +- Endothelial +- Epithelial +- Fibroblast +- Goblet +- LymphaticEndothelial +- Macrophage +- Mast +- Monocyte +- Myofibroblast +- NK +- NKT +- Nerve +- Neuron +- Neutrophil +- Pericyte +- SmoothMuscle +- Tcell +- Thrombocyte +- Treg +- Tumor +Lymph_Node: +- Bcell +- BloodVesselEndothelial +- CD4T +- CD8T +- Dendritic +- Epithelial +- Fibroblast +- LymphaticEndothelial +- Macrophage +- Mast +- Monocyte +- Myofibroblast +- NK +- NKT +- Neutrophil +- Plasma +- SmoothMuscle +- Tcell +- Treg +- Tumor +Lymphnode: +- Bcell +- BloodVesselEndothelial +- CD4T +- CD8T +- Dendritic +- Epithelial +- Fibroblast +- LymphaticEndothelial +- Macrophage +- Mast +- Monocyte +- Myofibroblast +- NK +- NKT +- Neutrophil +- Plasma +- SmoothMuscle +- Tcell +- Treg +- Tumor +Musculoskeletal: +- BloodVesselEndothelial +- CD8T +- Fibroblast +- Nerve +- SmoothMuscle +- Tcell +- Treg +- Tumor +Nervous: +- Astrocyte +- Bcell +- BloodVesselEndothelial +- CD4T +- CD8T +- Dendritic +- Endocrine +- Endothelial +- Fibroblast +- Glial +- Macrophage +- Mast 
+- Microglial +- Monocyte +- Myofibroblast +- NK +- Nerve +- Neuron +- Neutrophil +- Pericyte +- SmoothMuscle +- Tcell +- Treg +- Tumor +Pancreas: +- AlphaCell +- BetaCell +- Bcell +- BloodVesselEndothelial +- CD4T +- CD8T +- Dendritic +- Endocrine +- Endothelial +- Epithelial +- Fibroblast +- Macrophage +- Monocyte +- Myofibroblast +- NK +- Neuron +- Neutrophil +- SmoothMuscle +- Stromal +- Tcell +- Treg +- Tumor +Renal: +- BloodVesselEndothelial +- CD4T +- CD8T +- CollectingDuct +- Dendritic +- Endothelial +- Epithelial +- Fibroblast +- LymphaticEndothelial +- Macrophage +- Mast +- Mesangial +- Monocyte +- Myofibroblast +- NK +- NKT +- Neutrophil +- Podocyte +- SmoothMuscle +- Tcell +- Treg +- Tumor +Reproductive: +- Bcell +- BloodVesselEndothelial +- CD4T +- CD8T +- Dendritic +- EVT +- Endothelial +- Epithelial +- Fibroblast +- Macrophage +- Mast +- Monocyte +- Myofibroblast +- NK +- NKT +- Nerve +- Neuron +- Neutrophil +- Pericyte +- SmoothMuscle +- Tcell +- Treg +Skin: +- Bcell +- BloodVesselEndothelial +- CD4T +- CD8T +- Dendritic +- Endothelial +- Epithelial +- Fibroblast +- Langerhans +- LymphaticEndothelial +- Macrophage +- Melanocyte +- Myofibroblast +- NK +- NKT +- Nerve +- Neuron +- Neutrophil +- Pericyte +- Plasma +- SmoothMuscle +- Tcell +- Treg +- Tumor +Spleen: +- Bcell +- BloodVesselEndothelial +- CD4T +- CD8T +- Dendritic +- Erythrocyte +- Fibroblast +- LittoralCell +- LymphaticEndothelial +- Macrophage +- Monocyte +- Myofibroblast +- NK +- NKT +- Neutrophil +- Plasma +- SmoothMuscle +- Tcell +- Tumor +Thymus: +- Bcell +- BloodVesselEndothelial +- CD4T +- CD8T +- Dendritic +- Endothelial +- Epithelial +- Fibroblast +- LymphaticEndothelial +- Macrophage +- NKT +- Neutrophil +- Plasma +- SmoothMuscle +- Tcell +- Treg +Tonsil: +- Bcell +- BloodVesselEndothelial +- CD4T +- CD8T +- Dendritic +- Endocrine +- Endothelial +- Epithelial +- Fibroblast +- Goblet +- LymphaticEndothelial +- Macrophage +- Mast +- Monocyte +- Myofibroblast +- NK +- NKT +- Neuron 
def get_neighbor_masks(mask_patch, cbox, cell_id):
    """Split a label patch into a mask of one cell and a mask of all other cells.

    Args:
        mask_patch: integer label array; ``0`` is treated as background.
        cbox: 4-item crop box ``(minr, minc, maxr, maxc)``. It is only
            unpacked (which checks its length); the values are not used.
        cell_id: Python ``int`` label of the cell of interest.

    Returns:
        tuple: ``(self_mask, neighbor_mask)``, both ``int32`` arrays with the
        shape of ``mask_patch``. ``self_mask`` is 1 where the label equals
        ``cell_id``; ``neighbor_mask`` is 1 on every other non-background
        pixel.
    """
    _top, _left, _bottom, _right = cbox
    assert np.issubdtype(mask_patch.dtype, np.integer) and isinstance(cell_id, int)

    belongs_to_cell = mask_patch == cell_id
    is_foreground = mask_patch != 0

    self_mask = belongs_to_cell.astype(np.int32)
    # A neighbor pixel is labelled (non-zero) but carries a different id.
    neighbor_mask = (~belongs_to_cell & is_foreground).astype(np.int32)
    return self_mask, neighbor_mask
def combine_masks(raw, mask):
    """Stack every raw channel with the two cell masks.

    Args:
        raw: array of shape ``(C, H, W)`` -- one image plane per channel.
        mask: array of shape ``(H, W, 2)`` -- self mask and neighbor mask
            stacked on the last axis.

    Returns:
        Array of shape ``(C, 3, H, W)``. For every channel ``c``:
        ``out[c, 0]`` is ``raw[c]``, ``out[c, 1]`` is ``mask[..., 0]`` and
        ``out[c, 2]`` is ``mask[..., 1]``.
    """
    # Move the mask-kind axis to the front WITHOUT touching the spatial order.
    # (np.swapaxes(mask, 0, 2) would yield (2, W, H): the masks would be
    # transposed relative to ``raw`` and non-square patches would not even
    # concatenate.)
    mask = np.moveaxis(mask, -1, 0)  # (2, H, W)
    mask = np.expand_dims(mask, axis=0)  # (1, 2, H, W)
    raw_aug_mask = np.concatenate(
        [
            np.expand_dims(raw, axis=1),  # (C, 1, H, W)
            np.tile(mask, (raw.shape[0], 1, 1, 1)),  # (C, 2, H, W)
        ],
        axis=1,
    )  # (C, 3, H, W)
    return raw_aug_mask
def get_cell_ids(mask, df):
    """Return the segmentation label of every cell listed in ``df``.

    Args:
        mask: 2-D integer label image, indexed as ``mask[row, col]``.
        df: DataFrame with ``"Centroid.X.px"`` / ``"Centroid.Y.px"`` columns
            and, optionally, a ``"cell_ID"`` column.

    Returns:
        1-D integer array with one label per row of ``df``. When ``df`` has a
        ``"cell_ID"`` column it is returned as-is (cast to int); otherwise the
        label is read from ``mask`` at the rounded centroid.

    Raises:
        ValueError: if a centroid falls on background (label ``0``).
    """
    xs = df["Centroid.X.px"].round().astype(int).to_numpy()
    ys = df["Centroid.Y.px"].round().astype(int).to_numpy()

    if "cell_ID" in df.columns:
        return df["cell_ID"].astype(int).to_numpy()

    # One vectorised lookup instead of a Python loop.  int64 (not uint16):
    # label images routinely contain more than 65535 cells and a 16-bit buffer
    # would wrap or overflow on those ids.
    cell_ids = np.asarray(mask)[ys, xs].astype(np.int64)
    if np.any(cell_ids == 0):
        raise ValueError('Warning! A centroid points outside of cell')

    return cell_ids
@delayed
def get_properties(img_path, channel_masking, mpp):
    """Compute the per-channel minimum and maximum of one image (dask-delayed).

    Args:
        img_path: path handed to ``load_img``.
        channel_masking: per-channel booleans; channels flagged ``True`` are
            dropped before any statistic is computed.
        mpp: resolution of the image; the image is rescaled by
            ``mpp / DCTConfig().STANDARD_MPP_RESOLUTION``.

    Returns:
        tuple: ``(min_vals, max_vals)``, each reduced over the two spatial
        axes with ``keepdims=True`` (channel axis last).
    """
    pixels = load_img(img_path, '').astype(np.float32)
    keep_channel = ~np.array(channel_masking)
    # Channel-first on disk -> channel-last for skimage's rescale.
    pixels = np.transpose(pixels[keep_channel, :, :], (1, 2, 0))

    target_mpp = DCTConfig().STANDARD_MPP_RESOLUTION
    pixels = rescale(pixels, mpp / target_mpp,
                     preserve_range=True, channel_axis=-1)

    spatial_axes = (0, 1)
    lowest = np.min(pixels, axis=spatial_axes, keepdims=True)
    highest = np.max(pixels, axis=spatial_axes, keepdims=True)
    return lowest, highest
@delayed
def cell_seg(img_path,
             mask_path,
             df,
             cell_cutout,
             properties,
             channel_masking,
             mpp,
             dct_config):
    """Cut out every cell of one ROI and save the patches next to the image.

    The image and its label mask are loaded, normalised and padded by
    ``process``; ``patch_generator`` then yields one crop per row of ``df``.
    Each crop is combined with its self/neighbor masks, resized to 64x64 when
    ``cell_cutout`` differs from 64, and stored in an array of shape
    ``(n_cells, C, 3, 64, 64)`` that is written to
    ``<img_path without '.ome.tif'>_cells.npy``.

    Args:
        img_path: path of the ROI image (expected to contain ``.ome.tif``).
        mask_path: path of the matching label mask.
        df: rows of the cell table belonging to this ROI.
        cell_cutout: side length of the crop taken around each centroid.
        properties: tuple produced by ``get_global_properties``.
        channel_masking: per-channel booleans; ``True`` channels are dropped.
        mpp: resolution of the image, forwarded to ``process``.
        dct_config: configuration object forwarded to the helpers.

    Returns:
        None. The result is the ``.npy`` file written to disk.
    """
    out_size = 64  # spatial size of the stored patches

    raw = load_img(img_path, '').astype(np.float32)[~np.array(channel_masking), :, :]
    mask = np.squeeze(load_img(mask_path, '')).astype(np.uint32)
    # Ids are read from the mask before ``process`` rescales and pads it.
    cell_ids = get_cell_ids(mask, df)
    raw, mask = process(raw,
                        mask,
                        cell_cutout,
                        mpp,
                        properties,
                        dct_config)

    cell_results = torch.empty((
        df.shape[0],
        raw.shape[2],
        3,
        out_size,
        out_size
    ), dtype=torch.float32)

    for sample_patch, mask_patch, i in patch_generator(raw,
                                                       mask,
                                                       df,
                                                       cell_ids,
                                                       cell_cutout,
                                                       properties,
                                                       dct_config):
        cell = torch.as_tensor(
            combine_masks(sample_patch, mask_patch),  # (C, 3, H, W)
            dtype=torch.float32
        )
        if cell_cutout != out_size:
            cell = F.interpolate(
                cell,
                size=(out_size, out_size),
                mode="bilinear",
                align_corners=False
            )
        cell_results[i] = cell

    np.save(img_path.split('.ome.tif')[0] + '_cells.npy', cell_results.numpy())
    # No explicit ``del`` of the locals: they are released when the function
    # returns, and deleting the loop variables raised UnboundLocalError for
    # ROIs without any cell (the loop body never ran).
def flatten_nested_dict(nested_dict):
    """Collect every key of a nested dict, at any depth, into one sorted list.

    Args:
        nested_dict: mapping whose values are either falsy (leaf) or another
            mapping of the same kind.

    Returns:
        list: the unique keys found at all levels, in ascending order.
    """
    names = set()
    for name, children in nested_dict.items():
        names.add(name)
        # Only non-empty children are descended into.
        if children:
            names.update(flatten_nested_dict(children))
    return sorted(names)
def choose_channels(channel_names, channel_mapping):
    """Select the channels to keep and translate their names.

    Args:
        channel_names: iterable of channel names, in image order.
        channel_mapping: dict with a ``"channels_kept"`` mapping
            (original name -> new name) and a ``"channels_dropped"``
            container of names to discard.

    Returns:
        tuple: ``(channel_slices, channel_names_updated)`` -- the positions of
        the kept channels as an index array, and their translated names.

    Raises:
        ValueError: if a name is neither kept nor dropped.
    """
    keep_flags = []
    renamed = []
    for name in channel_names:
        is_kept = name in channel_mapping["channels_kept"]
        if not is_kept and name not in channel_mapping["channels_dropped"]:
            raise ValueError(f"Channel name {name} not found in channel_mapping.yaml")
        keep_flags.append(is_kept)
        if is_kept:
            renamed.append(channel_mapping["channels_kept"][name])

    kept_positions = np.where(keep_flags)[0]

    return kept_positions, renamed
_latest = "2025-06-09"
_model_registry = {
    # Original model version uploaded with preprint
    "specific_ct_v0.1": "e499da92509821161be88a47237960a9",
    # Versions released June 9th 2025. The public-data-only version is trained
    # only on the subset of data that is publicly available (for reproducibility).
    # Users are recommended to use the *non* public-data-only option.
    "2025-06-09": "19b669675c06816414e8677f542ff542",
    # NOTE(review): same checksum as "2025-06-09" above -- two different weight
    # files cannot share an md5; confirm the intended value.
    "2025-06-09_public-data-only": "19b669675c06816414e8677f542ff542",
}

# User interface to authentication layer for data/models.
_api_endpoint = "https://users.deepcell.org/api/getData/"
_asset_location = Path.home() / ".deepcell"

# Seconds a single HTTP connect/read may block before the request is aborted.
_REQUEST_TIMEOUT = 60


def dct_download(download_location, *, version=None):
    """Download the deepcell-types model for local use.

    The checkpoint is fetched into ``download_location`` and then reduced to
    the image-encoder weights by ``dct_modify``.

    Parameters
    ----------
    download_location : path-like
        Directory that receives the downloaded and the converted weights.
    version : str, optional
        Which version of the model to download. Default is `None`, which results
        in the latest (i.e. most-recently-released) version being downloaded.
    """

    version = version if version is not None else _latest
    asset_key = f"models/deepcell-types_{version}.pt"

    weights_path = fetch_data(
        asset_key, cache_subdir="models", file_hash=_model_registry.get(version),
        download_location=download_location)

    # Convert the file that was actually downloaded, whatever its version.
    dct_modify(download_location, weights_filename=os.path.basename(weights_path))


def dct_modify(model_path, weights_filename="deepcell-types_2025-06-09.pt"):
    """Strip a deepcell-types checkpoint down to its image-encoder weights.

    The checkpoint ``<model_path>/<weights_filename>`` is loaded on the CPU,
    the ``image_encoder.`` prefix is removed from every key, every entry whose
    key contains one of the dropped names below is discarded, and the result
    is saved as ``<model_path>/deepcell_types_modified.pt``. The original
    checkpoint is deleted afterwards.

    Args:
        model_path: directory holding the checkpoint.
        weights_filename: name of the checkpoint inside ``model_path``;
            defaults to the 2025-06-09 release.
    """
    weight_path = os.path.join(model_path, weights_filename)
    weights = torch.load(weight_path,
                         map_location=torch.device('cpu')
                         )

    stripped_prefixes = ['image_encoder.']
    for prefix in stripped_prefixes:
        weights = {
            k.replace(prefix, ''): v
            for k, v in weights.items()
        }

    dropped = ("logit_scale", "ct_embedding.weight", "image_adaptor.0.weight",
               "image_adaptor.0.bias", "text_adaptor.weight", "text_adaptor.bias",
               "classification_head.")
    # Filter in one pass: a key matching several patterns is dropped once
    # (popping it once per matching pattern raised KeyError).
    weights = {
        k: v
        for k, v in weights.items()
        if not any(tag in k for tag in dropped)
    }

    modified_path = os.path.join(model_path, "deepcell_types_modified.pt")
    torch.save(weights, modified_path)

    if os.path.exists(weight_path):
        os.remove(weight_path)


def _md5_of_file(path, chunk_size=1 << 20):
    """Return the hex md5 of ``path``, read in chunks to bound memory use."""
    digest = md5()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def _error_payload(resp):
    """Return the JSON object of an error response, or ``{}`` if it has none."""
    try:
        payload = resp.json()
    except ValueError:
        return {}
    return payload if isinstance(payload, dict) else {}


def fetch_data(asset_key: str, cache_subdir=None, file_hash=None, download_location=None):
    """Fetch assets through users.deepcell.org authentication system.

    Download assets from the deepcell suite of datasets and models which
    require user-authentication.

    .. note::

       You must have a Deepcell Access Token set as an environment variable
       with the name ``DEEPCELL_ACCESS_TOKEN`` in order to access assets.

       Access tokens can be created at https://users.deepcell.org

    Args:
        :param asset_key: Key of the file to download.
        The list of available assets can be found on the users.deepcell.org
        homepage.

        :param cache_subdir: `str` indicating directory relative to
        `~/.deepcell` where downloaded data will be cached. Only used when
        ``download_location`` is `None`.

        :param file_hash: `str` represented the md5 checksum of datafile. The
        checksum is used to perform data caching. If no checksum is provided or
        the checksum differs from that found in the data cache, the data will
        be (re)-downloaded.

        :param download_location: directory the file is written to. The
        default is `None`, which means ``~/.deepcell`` (plus ``cache_subdir``
        when given).

    Returns:
        Path of the local file, as a string.

    Raises:
        ValueError: if no access token is set, the token is rejected, or the
            asset does not exist.
        requests.HTTPError: for any other non-success HTTP status.
    """
    logging.basicConfig(level=logging.INFO)

    if download_location is None:
        # Documented default: cache under ~/.deepcell[/cache_subdir].
        download_location = _asset_location
        if cache_subdir is not None:
            download_location = download_location / cache_subdir

    # Extract the filename from the asset_key, which can be a full path
    fname = os.path.split(asset_key)[-1]
    fpath = os.path.join(download_location, fname)

    # Check for cached data
    if file_hash is not None:
        logging.info('Checking for cached data')
        try:
            md5sum = _md5_of_file(fpath)
        except FileNotFoundError:
            pass
        else:
            logging.info(f"Checking {fname} against provided file_hash...")
            if md5sum == file_hash:
                logging.info(
                    f"{fname} with hash {file_hash} already available."
                )
                return fpath
            logging.info(
                f"{fname} with hash {file_hash} not found in {download_location}"
            )

    # Check for access token
    access_token = os.environ.get("DEEPCELL_ACCESS_TOKEN")
    if access_token is None:
        raise ValueError(
            "\nDEEPCELL_ACCESS_TOKEN not found.\n"
            "Please set your access token to the DEEPCELL_ACCESS_TOKEN\n"
            "environment variable.\n"
            "For example:\n\n"
            "\texport DEEPCELL_ACCESS_TOKEN=.\n\n"
            "If you don't yet have a token, you can create one at\n"
            "https://users.deepcell.org"
        )

    # Request download URL
    headers = {"X-Api-Key": access_token}
    logging.info("Making request to server")
    resp = requests.post(
        _api_endpoint, headers=headers, data={"s3_key": asset_key},
        timeout=_REQUEST_TIMEOUT,
    )
    # Raise informative exception for the specific case when the asset_key is
    # not found in the bucket
    if resp.status_code == 404 and _error_payload(resp).get("error") == "Key not found":
        raise ValueError(f"Object {asset_key} not found.")
    # Raise informative exception for the specific case when an invalid
    # API token is provided.
    if resp.status_code == 403 and (
        _error_payload(resp).get("detail") == "Authentication credentials were not provided."
    ):
        # The token is a secret: never echo it into an exception message,
        # which typically ends up in logs and tracebacks.
        raise ValueError(
            "\n\nThe API token in DEEPCELL_ACCESS_TOKEN is not valid.\n"
            "The token may be expired - if so, create a new one at\n"
            "https://users.deepcell.org"
        )
    # Handle all other non-http-200 status
    resp.raise_for_status()

    # Parse response
    response_data = resp.json()
    download_url = response_data["url"]
    file_size = response_data["size"]
    # Parse file_size (TODO: would be more convenient if it were numerical, i.e. always bytes)
    val, suff = file_size.split(" ")
    suffix_mapping = {"KB": 2**10, "MB": 2**20, "B": 1, "GB": 2**30, "TB": 2**40}
    file_size_numerical = int(float(val) * suffix_mapping[suff])

    logging.info(
        f"Downloading {asset_key} with size {file_size} to {download_location}"
    )
    data_req = requests.get(
        download_url, headers={"user-agent": "Wget/1.20 (linux-gnu)"}, stream=True,
        timeout=_REQUEST_TIMEOUT,
    )
    data_req.raise_for_status()

    os.makedirs(download_location, exist_ok=True)
    chunk_size = 4096
    with tqdm.wrapattr(
        open(fpath, "wb"), "write", miniters=1, total=file_size_numerical
    ) as fh:
        for chunk in data_req.iter_content(chunk_size=chunk_size):
            fh.write(chunk)

    logging.info(f"🎉 Successfully downloaded file to {fpath}")

    return fpath
"kronos_vits16_model.pt" + + # Download checkpoint from Hugging Face Hub + checkpoint_path = hf_hub_download( + checkpoint_path[len("hf_hub:"):], + cache_dir=cache_dir, + filename=checkpoint_filename, + token=hf_auth_token + ) + + # Load the state dictionary, removing specific prefixes and entries + state_dict = torch.load(checkpoint_path, map_location='cpu') + state_dict = state_dict['teacher'] + state_dict = {k.replace('backbone.', ''): v for k, v in state_dict.items()} + state_dict = {k: v for k, v in state_dict.items() if 'dino_head' not in k} + + print(f"\033[92mLoaded model weights from {checkpoint_path}\033[0m") + + return state_dict + +def kr_download(model_path): + + state_dict = create_model_from_pretrained( + checkpoint_path="hf_hub:MahmoodLab/kronos", # Make sure you have requested access on HuggingFace + cache_dir="./model_assets", + ) + torch.save(state_dict, os.path.join(model_path,"kronos_weights.pt")) + if os.path.exists(os.path.join(model_path,"model_assets")): + shutil.rmtree(os.path.join(model_path,"model_assets")) + if os.path.exists(os.path.join(model_path,".locks")): + shutil.rmtree(os.path.join(model_path,".locks")) diff --git a/src/utils/eva_kit/__init__.py b/src/utils/eva_kit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/utils/eva_kit/config.yaml b/src/utils/eva_kit/config.yaml new file mode 100644 index 0000000..4e50c46 --- /dev/null +++ b/src/utils/eva_kit/config.yaml @@ -0,0 +1,27 @@ +# Eva Model Configuration +ds: # dataset + patch_size: 224 + token_size: 8 + marker_dim: 3072 + mask_strategy: "random" # Required for model initialization + mask_ratio: 0. 
# Required for model initialization, set to 0 for tasks + +cm: # channel model + dim: 512 + mlp_ratio: 4 + n_heads: 4 + n_layers: 2 + +pm: # patch model + dim: 768 + mlp_ratio: 4 + n_heads: 12 + n_layers: 12 + out_dim: 512 + +de: # decoder model + dim: 512 + marker_dim: 512 + mlp_ratio: 4 + n_heads: 16 + n_layers: 8 diff --git a/src/utils/eva_kit/constant.py b/src/utils/eva_kit/constant.py new file mode 100644 index 0000000..7bca36a --- /dev/null +++ b/src/utils/eva_kit/constant.py @@ -0,0 +1,209 @@ +# -*- coding: utf-8 -*- +""" +@Time : 2025/05/14 22:46 +@Author : Yufan Liu +@Desc : Biomarker-Gene mapping. +""" + +marker_to_gene = { + "CD38": "CD38", + "S100A4": "S100A4", + "CD127": "IL7R", + "SOX2": "SOX2", + "F4/80": "ADGRE1", + "CD4": "CD4", + "Geminin": "GMNN", + "CD8": "CD8A", + "HistoneH3p": "H3F3A", + "pS6": "RPS6", + "CD227": "MUC1", + "PDL2": "PDCD1LG2", + "CD140b": "PDGFRB", + "CD117": "KIT", + "Perforin": "PRF1", + "ATM": "ATM", + "TCF1": "TCF7", + "PGP9.5": "UCHL1", + "CD39": "ENTPD1", + "CD271": "NGFR", + "Perilipin": "PLIN1", + "CD137": "TNFRSF9", + "Caveolin1": "CAV1", + "BCL6": "BCL6", + "CD25": "IL2RA", + "VISTA": "VSIR", + "Keratin10": "KRT10", + "CD33": "CD33", + "CD209": "CD209", + "GFAP": "GFAP", + "INOS": "NOS2", + "NKG2D": "KLRK1", + "Fibronectin": "FN1", + "SOX10": "SOX10", + "Clusterin": "CLU", + "Keratin14": "KRT14", + "Na/K ATPase": "ATP1A1", + "DAPI": "DAPI", + "MMP12": "MMP12", + "CD152": "CTLA4", + "CD162": "SELPLG", + "PCNA": "PCNA", + "CD11c": "ITGAX", + "LAG3": "LAG3", + "CD134": "TNFRSF4", + "ChromograninA": "CHGA", + "PanCK": "KRT", + "NPM1": "NPM1", + "CD90": "THY1", + "Keratin7": "KRT7", + "CD44": "CD44", + "CD3e": "CD3E", + "HistoneH3": "H3F3A", + "Keratin17": "KRT17", + "CD2": "CD2", + "GLUT1": "SLC2A1", + "c-Myc": "MYC", + "TCRb": "TRB", + "pATM": "ATM", + "CDX2": "CDX2", + "CD7": "CD7", + "CD20": "MS4A1", + "IFNg": "IFNG", + "CD19": "CD19", + "Tenascin-C": "TNC", + "Ki67": "MKI67", + "CD14": "CD14", + "CD208": "CD207", + 
"PGR": "PGR", + "CXCR5": "CXCR5", + "CXCL13": "CXCL13", + "OX40": "TNFRSF4", + "CD86": "CD86", + "CD138": "SDC1", + "EpCAM": "EPCAM", + "CD69": "CD69", + "RAD51": "RAD51", + "CD30": "TNFRSF8", + "BCL2": "BCL2", + "CD1c": "CD1C", + "CD21": "CR2", + "Keratin8/18": "KRT8", + "PD1": "PDCD1", + "GranzymeB": "GZMB", + "yH2AX": "H2AX", + "PDL1": "CD274", + "Twist": "TWIST1", + "Keratin19": "KRT19", + "CollagenIV": "COL4A1", + "CD47": "CD47", + "CD74": "CD74", + "CD94": "KLRD1", + "CXCL12": "CXCL12", + "FAP": "FAP", + "CD107a": "LAMP1", + "CXCR1": "CXCR1", + "CD163": "CD163", + "cPARP-cCasp3": "PARP1", + "CD183": "CXCR3", + "Ly6G": "LY6G6C", + "Tryptase": "TPSB2", + "CD196": "CCR6", + "Arg-1": "ARG1", + "RORgammaT": "RORC", + "GATA3": "GATA3", + "S100A8/9": "S100A8", + "EGFR": "EGFR", + "TCRgammadelta": "TRGC1", + "IRF4": "IRF4", + "CD57": "B3GAT1", + "pSTAT3": "STAT3", + "SNAI2": "SNAI2", + "mTOR": "MTOR", + "bCatenin": "CTNNB1", + "Vimentin": "VIM", + "CollagenI": "COL1A1", + "CD197": "CCR7", + "ICOS": "ICOS", + "CD206": "MRC1", + "Nestin": "NES", + "MMP9": "MMP9", + "CD79": "CD79A", + "CD36": "CD36", + "p-ERK1/2": "MAPK3", + "CD56": "NCAM1", + "CD73": "NT5E", + "CD40": "CD40", + "CXCR2": "CXCR2", + "CD66b": "CEACAM8", + "XCR1": "XCR1", + "aSMA": "ACTA2", + "CD49": "ITGA4", + "TFAM": "TFAM", + "Periostin": "POSTN", + "CD45": "PTPRC", + "CD123": "IL3RA", + "TMEM16A": "ANO1", + "TIGIT": "TIGIT", + "CD45RA": "PTPRC", + "LOX-1": "OLR1", + "PARP1": "PARP1", + "HLA-E": "HLA-E", + "Perlecan": "HSPG2", + "CD34": "CD34", + "TTF1": "NKX2-1", + "RPS6": "RPS6", + "Keratin8": "KRT8", + "CD27": "CD27", + "CD103": "ITGAE", + "Podoplanin": "PDPN", + "TOX": "TOX", + "SOX9": "SOX9", + "Tbet": "TBX21", + "Olig2": "OLIG2", + "HIF1a": "HIF1A", + "PNAD": "PNAD", + "CD68": "CD68", + "Pax5": "PAX5", + "CD71": "TFRC", + "IDO1": "IDO1", + "MCT": "SLC16A1", + "CD146": "MCAM", + "CA9": "CA9", + "CD16": "FCGR3A", + "CD45RO": "PTPRC", + "HLA-ABC": "HLA-A", + "p16": "CDKN2A", + "FoxP3": "FOXP3", + 
"CD194": "CCR4", + "HER2": "ERBB2", + "HLA-DR": "HLA-DRA", + "H3K27me3": "H3F3A", + "Gal3": "LGALS3", + "Synaptophysin": "SYP", + "CD11b": "ITGAM", + "CD207": "CD207", + "CD141": "THBD", + "Keratin5": "KRT5", + "CD62L": "SELL", + "CD276": "CD276", + "ECad": "CDH1", + "p53": "TP53", + "ERa": "ESR1", + "TIM3": "HAVCR2", + "CD15": "FUT4", + "Siglec8": "SIGLEC8", + "CD31": "PECAM1", + "CX3CR1": "CX3CR1", + "LYVE1": "LYVE1", + "CD66": "CEACAM6", + "TP63": "TP63", + "CD5": "CD5", + "MPO": "MPO", + "HECHA1": "HECHA1", + "HECHA2": "HECHA2", + "HECHA3": "HECHA3", + "PAD": "PAD", +} + +hande_marker = ["HECHA1", "HECHA2", "HECHA3"] + diff --git a/src/utils/eva_kit/global_properties.py b/src/utils/eva_kit/global_properties.py new file mode 100644 index 0000000..d03a486 --- /dev/null +++ b/src/utils/eva_kit/global_properties.py @@ -0,0 +1,24 @@ +import dask +import dask.array as da +from dask import delayed +import numpy as np +from src.utils.image_preprocess import load_img + +@delayed +def get_properties(img_path, channel_masking, mpp): + img_data = load_img(img_path,'').astype(np.float32) + min_vals = np.min(img_data, axis=(1, 2), keepdims=True) + max_vals = np.max(img_data, axis=(1, 2), keepdims=True) + return min_vals, max_vals + +def get_global_properties(img_paths): + + tasks = [get_properties(p, channel_masking, mpp) for p in img_paths] + results = dask.compute(*tasks) + mins = [mn for (mn, mx) in results] + maxs = [mx for (mn, mx) in results] + + min_vals = np.min(np.stack(mins, axis=0), axis=0) + max_vals = np.max(np.stack(maxs, axis=0), axis=0) + + return min_vals, ptp_vals, img_max diff --git a/src/utils/eva_kit/image_funcs.py b/src/utils/eva_kit/image_funcs.py new file mode 100644 index 0000000..bd02304 --- /dev/null +++ b/src/utils/eva_kit/image_funcs.py @@ -0,0 +1,263 @@ +import numpy as np +from tqdm import tqdm +import os +import pandas as pd +import torch +import torch.nn.functional as F +import warnings +import dask +import dask.array as da +from dask import 
def get_mif_normalization_bounds(
    image_paths,
    channel_mask,
    lower_percentile=1.0,
    upper_percentile=99.9,
    channel_first=True,
    ignore_zeros=True,
):
    """Extract per-channel percentile bounds from several MIF images.

    Every image is loaded with ``load_img``, reduced to the channels selected
    by ``channel_mask`` and its pixel values are pooled per channel across all
    images before the percentiles are taken.

    Args:
        image_paths: Paths of the images to pool; must not be empty.
        channel_mask: Index or boolean mask selecting the channels to keep.
        lower_percentile: Percentile used for the lower bound.
        upper_percentile: Percentile used for the upper bound.
        channel_first: ``True`` if loaded images are (C, H, W), ``False`` if
            they are (H, W, C).
        ignore_zeros: If ``True`` only strictly positive values are pooled.

    Returns:
        ``(lower_bounds, upper_bounds)``, both float32 arrays of shape (C,).
        Channels without any pooled value keep the defaults 0 and 1.

    Raises:
        ValueError: If ``image_paths`` is empty.
    """
    channel_values = None

    for path in image_paths:
        img = load_img(path, "").astype(np.float32)

        # Select channels on the actual channel axis and bring the image to
        # (H, W, C). The mask used to be applied on axis 0 unconditionally,
        # which is the row axis for channel-last images.
        if channel_first:
            img = np.transpose(img[channel_mask, :, :], (1, 2, 0))
        else:
            img = img[:, :, channel_mask]

        n_channels = img.shape[-1]

        if channel_values is None:
            channel_values = [[] for _ in range(n_channels)]

        for c in range(n_channels):
            vals = img[..., c].ravel()
            if ignore_zeros:
                vals = vals[vals > 0]

            if vals.size > 0:
                channel_values[c].append(vals)

    if channel_values is None:
        raise ValueError("image_paths is empty; cannot derive normalization bounds")

    lower_bounds = np.zeros(len(channel_values), dtype=np.float32)
    upper_bounds = np.ones(len(channel_values), dtype=np.float32)

    for c, vals_list in enumerate(channel_values):
        if not vals_list:
            continue

        vals = np.concatenate(vals_list)
        lower_bounds[c] = np.percentile(vals, lower_percentile)
        upper_bounds[c] = np.percentile(vals, upper_percentile)

    return lower_bounds, upper_bounds
def patch_generator(raw, mask, cell_cutout, df):
    """Yield a square image patch and a binary cell mask for every cell in ``df``.

    Each patch is a ``cell_cutout`` x ``cell_cutout`` window around the rounded
    centroid. Parts of the window that fall outside the image are zero-padded.

    Args:
        raw: Image array indexed as (channel, row, column).
        mask: Label image indexed as (row, column).
        cell_cutout: Side length of the window in pixels (even or odd).
        df: DataFrame with ``Centroid.X.px`` (column) and ``Centroid.Y.px``
            (row) and optionally ``cell_ID``. Without ``cell_ID`` the label is
            read from ``mask`` at the centroid.

    Yields:
        ``(raw_patch, mask_patch, i)``: the (C, cell_cutout, cell_cutout) image
        window, a uint8 mask that is 1 where ``mask`` equals the cell's label,
        and the positional index of the cell in ``df``.

    Raises:
        ValueError: If a centroid lies on background (label 0) while labels
            are looked up from ``mask``, or if a window cannot be completed
            because the centroid lies too far outside the image.
    """
    cols = df["Centroid.X.px"].round().astype(int).to_numpy()
    rows = df["Centroid.Y.px"].round().astype(int).to_numpy()

    if "cell_ID" in df.columns:
        cell_ids = df["cell_ID"].astype(int).to_numpy()
    else:
        # int64 instead of uint16: label images routinely exceed 65535 cells.
        cell_ids = np.asarray(mask[rows, cols]).astype(np.int64)
        if np.any(cell_ids == 0):
            raise ValueError('Warning! A centroid points outside of cell')

    half = cell_cutout // 2
    n_rows, n_cols = raw.shape[1], raw.shape[2]

    for i, (row, col, cell_id) in enumerate(zip(rows, cols, cell_ids)):
        # Unclipped window; using start + cell_cutout keeps odd sizes exact.
        top, left = row - half, col - half
        bottom, right = top + cell_cutout, left + cell_cutout

        r0, r1 = max(top, 0), min(bottom, n_rows)
        c0, c1 = max(left, 0), min(right, n_cols)

        raw_patch = raw[:, r0:r1, c0:c1]
        mask_patch = mask[r0:r1, c0:c1]

        pad_rows = (r0 - top, bottom - r1)
        pad_cols = (c0 - left, right - c1)
        if any(pad_rows) or any(pad_cols):
            if min(pad_rows + pad_cols) < 0:
                raise ValueError("Patch size mismatch after padding")
            raw_patch = np.pad(raw_patch, ((0, 0), pad_rows, pad_cols),
                               mode='constant', constant_values=0)
            mask_patch = np.pad(mask_patch, (pad_rows, pad_cols),
                                mode='constant', constant_values=0)

        if raw_patch.shape[1:] != (cell_cutout, cell_cutout):
            raise ValueError("Patch size mismatch after padding")
        if mask_patch.shape != (cell_cutout, cell_cutout):
            raise ValueError("Mask size mismatch after padding")

        mask_patch = (mask_patch == cell_id).astype(np.uint8)

        yield raw_patch, mask_patch, i
{channel_name} is not in GenePT embeddings and will be masked!") + channel_mask[idx] = False + + mask_dir = os.path.join(os.getcwd(),path,'masks') + train_dir = os.path.join(os.getcwd(),path,'train') + test_dir = os.path.join(os.getcwd(),path,'test') + + if ids_path != "": + df = pd.read_csv(os.path.join(os.getcwd(),"data","raw",ids_path)) + else: + df = pd.read_csv([os.path.abspath(os.path.join(path, p)) for p in os.listdir(path) if p.lower().endswith(('csv'))][0]) + + img_names = list(set( + [re.sub(r'_\d+$', '', p.split(os.sep)[-1].split('.ome.tif')[0]) for p in df['Image'].tolist()] + )) + + print(">>> Detected image names:") + [print(p) for p in img_names] + + total_rois = len(os.listdir(mask_dir)) + current_total = 0 + for img_name in img_names: + img_paths = sorted(([os.path.join(train_dir,p) + for p in os.listdir(train_dir) + if p.startswith(img_name) and not p.endswith(('.npy','.pt'))] + + ([os.path.join(test_dir,p) + for p in os.listdir(test_dir) + if p.startswith(img_name) and not p.endswith(('.npy','.pt'))])), + key=extract_idx) + mask_paths = sorted([os.path.join(mask_dir,p) + for p in os.listdir(mask_dir) + if p.startswith(img_name)], + key=extract_idx) + + if len(img_paths) == 0: + print(f">>> Warning! No matching images for {img_name}") + continue + if len(mask_paths) == 0: + print(f">>> Warning! 
No matching masks for {img_name}") + continue + + properties = get_mif_normalization_bounds(img_paths, channel_mask) + + num_batches = (len(img_paths)+batch_size-1) // batch_size + print(f'>> Processing {img_name}, {len(img_paths)} ROIs in batches of {batch_size}') + for batch_idx in range(num_batches): + if batch_idx < num_batches - 1: + start = batch_idx*batch_size + end = batch_idx*batch_size+batch_size + else: + start = batch_idx*batch_size + end = len(img_paths) + tasks = [cell_seg(img_paths[p], + mask_paths[p], + df[df["Image"]==img_paths[p].split(os.sep)[-1]], + properties, + channel_mask, + cell_cutout) + for p in range(start,end)] + with ProgressBar(): + dask.compute(*tasks) + current_total += len(img_paths) + print(f'{current_total}/{total_rois} ROIs processed') diff --git a/src/utils/eva_kit/layers.py b/src/utils/eva_kit/layers.py new file mode 100644 index 0000000..6754431 --- /dev/null +++ b/src/utils/eva_kit/layers.py @@ -0,0 +1,286 @@ +# -*- coding: utf-8 -*- +""" +@Time : 2025/03/19 17:50 +@Author : Yufan Liu +@Desc : Some modules and layers for the model +""" + + +import torch +import torch.nn as nn +from timm.layers import Mlp +from timm.models.vision_transformer import Attention, Block +from torch.nn import functional as F + +from src.utils.eva_kit.constant import marker_to_gene + + +# ----------------------------- Marker Embeddings ---------------------------- # +class MarkerEmbeddingGenePT(nn.Module): + """GenePT-based marker embedding module that utilizes pre-computed GenePT embeddings for markers. + + This module maps marker names to their corresponding dense vector representations using + pre-computed GenePT embeddings. For markers without GenePT embeddings, it uses learned + embeddings initialized with Xavier uniform initialization. + + Args: + marker_dict: Dictionary containing pre-computed GenePT embeddings for markers. + unknown_marker_embed_dim: Dimension of embeddings for unknown markers. 
+ Defaults to 3072 (GenePT embedding dimension). + + Note: + - The module pre-initializes embedding layers for all markers that don't have GenePT embeddings + - All embeddings are initialized using Xavier uniform initialization + """ + + def __init__(self, marker_dict, unknown_marker_embed_dim=3072): + super().__init__() + self.genept_embeddings = marker_dict + self.unknown_marker_embeddings = nn.ModuleDict() + self.unknown_marker_embed_dim = unknown_marker_embed_dim + self.register_buffer("_device_tracker", torch.empty(0)) # For robust device tracking + + # Initialize embeddings for all markers that don't have GenePT embeddings + for marker_name in marker_to_gene.keys(): + m_gene = marker_to_gene[marker_name] + if m_gene not in self.genept_embeddings: + self.unknown_marker_embeddings[marker_name] = nn.Embedding(1, self.unknown_marker_embed_dim) + nn.init.xavier_uniform_(self.unknown_marker_embeddings[marker_name].weight) + + def forward(self, marker_names): + """Generate embeddings for a list of marker names. 
class MaskedAttention(Attention):
    """Attention mechanism with optional masking.

    Extends the ``Attention`` module so that selected key positions can be
    excluded from attention. The mask is honoured on both the fused
    (``scaled_dot_product_attention``) and the manual code path.

    Args:
        dim: Input/output dimension
        num_heads: Number of attention heads. Defaults to 8.
        qkv_bias: Whether to include bias in qkv projections. Defaults to False.
        qk_norm: Whether to normalize query and key. Defaults to False.
        proj_bias: Whether to include bias in output projection. Defaults to True.
        attn_drop: Dropout rate for attention weights. Defaults to 0.0.
        proj_drop: Dropout rate for output projection. Defaults to 0.0.
        norm_layer: Normalization layer to use. Defaults to nn.LayerNorm.
        fused_attn: Whether to use fused attention. Defaults to True.
    """

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_norm=False,
        proj_bias=True,
        attn_drop=0.0,
        proj_drop=0.0,
        norm_layer=nn.LayerNorm,
        fused_attn=True,
    ):
        super().__init__(
            dim=dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=norm_layer,
        )
        # Set after the parent constructor so the caller's choice wins.
        self.fused_attn = fused_attn

    def forward(self, x, attn_mask=None):
        """Forward pass with optional attention masking.

        Args:
            x: Input tensor [B, N, C]
            attn_mask: Optional mask broadcastable to [B, num_heads, N, N].
                Entries equal to 0 may be attended to; every other value
                blocks the position. Defaults to None (no masking).

        Returns:
            Output tensor [B, N, C]
        """
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        if attn_mask is not None:
            # Convert to a boolean "keep" mask: True where attention is allowed.
            attn_mask = attn_mask == 0

        if self.fused_attn:
            x = F.scaled_dot_product_attention(
                q, k, v, dropout_p=self.attn_drop.p if self.training else 0.0, attn_mask=attn_mask
            )

        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            if attn_mask is not None:
                # Mirror the fused path: blocked positions get zero weight
                # after the softmax. The mask used to be ignored here.
                attn = attn.masked_fill(~attn_mask, float("-inf"))
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class PatchEmbedChannelFree(nn.Module):
    """Channel agnostic patch embedding that applies one 2D convolution to each channel.

    Every channel is treated as a separate single-channel image and embedded
    with the same convolution weights, so the number of input channels can be
    arbitrary.

    Args:
        img_size: Expected input height/width (int for square, or a pair).
        token_size: Side length of one token (kernel size and stride).
        embed_dim: Output embedding dimension per token.
        norm_layer: Optional callable ``norm_layer(embed_dim)``; identity if None.
        bias: Whether the convolution has a bias term.
    """

    def __init__(
        self,
        img_size,
        token_size=16,
        embed_dim=256,
        norm_layer=None,
        bias=True,
    ):
        super().__init__()
        self.img_size = (img_size, img_size) if isinstance(img_size, int) else img_size
        self.token_size = (token_size, token_size) if isinstance(token_size, int) else token_size
        self.embed_dim = embed_dim

        # Create a single conv layer that will be applied to each channel
        self.proj = nn.Conv2d(1, embed_dim, kernel_size=token_size, stride=token_size, bias=bias)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

        # Calculate grid size and number of patches
        self.grid_size = (self.img_size[0] // self.token_size[0], self.img_size[1] // self.token_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]

    def forward(self, x):
        """Embed every channel of ``x`` into a sequence of tokens.

        Args:
            x: Input tensor of shape (B, C, H, W); H and W must equal
                ``img_size``.

        Returns:
            Output tensor of shape (B, C, num_patches, embed_dim).
        """
        B, C, H, W = x.shape
        assert (
            H == self.img_size[0] and W == self.img_size[1]
        ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."

        # reshape (not view): inputs produced by permute/strided slicing are
        # non-contiguous and cannot be viewed with batch and channel merged.
        x = x.reshape(B * C, 1, H, W)
        x = self.proj(x)
        x = x.flatten(2).transpose(1, 2)
        x = x.reshape(B, C, -1, self.embed_dim)

        x = self.norm(x)
        return x
+ + Args: + ratio: Masking ratio + strategy: Masking strategy: + - random: Randomly mask 2D content [C, N] + - patch: Mask all channels of selected patches + - channel: Mask all patches of a channel + - he: Mask all H&E channels (last 3) + - mif: Mask all MIF channels + - specified: Mask specified channels (requires channels list during call) + """ + + def random_mask(x, ratio=ratio): + """Randomly mask 2D content across [C, N] plane. + + Args: + x (Tensor): Input tensor of shape [B, C, N, D] + ratio (float): Masking ratio between 0 and 1 + + Returns: + Tensor: Binary mask of shape [C, N] + """ + B, C, N, D = x.shape + device = x.device + + noise = torch.rand(C, N, device=device) + + num_elements = C * N + num_keep = int(num_elements * (1 - ratio)) + + ids_shuffle = torch.argsort(noise.reshape(-1)) + + mask = torch.ones([C, N], device=device) + mask_flat = mask.reshape(-1) + mask_flat[ids_shuffle[:num_keep]] = 0 + mask = mask_flat.reshape(C, N) + + return mask + + def patch_mask(x, ratio=ratio): + """Mask all channels of selected patches. + + Args: + x (Tensor): Input tensor of shape [B, C, N, D] + ratio (float): Masking ratio between 0 and 1 + + Returns: + Tensor: Binary mask of shape [C, N] + """ + B, C, N, D = x.shape + device = x.device + + noise = torch.rand(N, device=device) + + num_elements = N + num_keep = int(num_elements * (1 - ratio)) + + ids_shuffle = torch.argsort(noise) + ids_keep = ids_shuffle[:num_keep] + + mask = torch.ones([N], device=device) + mask[ids_keep] = 0 + + mask = mask.unsqueeze(0).expand(C, -1) # [C, N] + + return mask + + def channel_mask(x, ratio=ratio): + """Mask all patches of selected channels. 
+ + Args: + x (Tensor): Input tensor of shape [B, C, N, D] + ratio (float or int): If float (0-1), use as masking ratio; if int, use as number of channels to mask + + Returns: + Tensor: Binary mask of shape [C, N] + """ + B, C, N, D = x.shape + device = x.device + + noise = torch.rand(C, device=device) + + if isinstance(ratio, int): + # Use as number of channels to mask + num_keep = C - ratio + else: + # Use as ratio + num_keep = int(C * (1 - ratio)) + + ids_shuffle = torch.argsort(noise) + ids_keep = ids_shuffle[:num_keep] + + mask = torch.ones([C], device=device) + mask[ids_keep] = 0 + + mask = mask.unsqueeze(1).expand(-1, N) # [C, N] + + return mask + + def he_mask(x, ratio=ratio): + """Mask H&E channels (last 3 channels). + + Args: + x (Tensor): Input tensor of shape [B, C, N, D] + ratio (float): Ignored for this strategy + + Returns: + Tensor: Binary mask of shape [C, N] + """ + B, C, N, D = x.shape + device = x.device + mask = torch.ones([C], device=device) + mask[:-3] = 0 + + mask = mask.unsqueeze(1).expand(-1, N) # [C, N] + + return mask + + def mif_mask(x, ratio=ratio): + """Mask MIF channels (all except last 3 channels). + + Args: + x (Tensor): Input tensor of shape [B, C, N, D] + ratio (float): Ignored for this strategy + + Returns: + Tensor: Binary mask of shape [C, N] + """ + B, C, N, D = x.shape + device = x.device + mask = torch.ones([C], device=device) + mask[-3:] = 0 + + mask = mask.unsqueeze(1).expand(-1, N) # [C, N] + + return mask + + def specified_mask(x, channels): + """Mask specified channels. 
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """Build a fixed 2D sine-cosine position table for a square grid.

    grid_size: int, height and width of the grid
    cls_token: if True, an all-zero row is prepended for the class token
    return:
    pos_embed: [grid_size*grid_size, embed_dim], or
               [1+grid_size*grid_size, embed_dim] when cls_token is True
    """
    axis = np.arange(grid_size, dtype=np.float32)
    # meshgrid is called as (w, h), so the first plane varies along the width
    planes = np.stack(np.meshgrid(axis, axis), axis=0)
    planes = planes.reshape([2, 1, grid_size, grid_size])

    table = get_2d_sincos_pos_embed_from_grid(embed_dim, planes)
    if not cls_token:
        return table
    return np.concatenate([np.zeros([1, embed_dim]), table], axis=0)


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode both coordinate planes of ``grid``, half of ``embed_dim`` each.

    grid: array whose entries grid[0] and grid[1] hold the two coordinate planes
    return: (H*W, embed_dim), grid[0] encoding first, grid[1] encoding second
    """
    assert embed_dim % 2 == 0

    half_dim = embed_dim // 2
    halves = [get_1d_sincos_pos_embed_from_grid(half_dim, grid[axis]) for axis in (0, 1)]
    return np.concatenate(halves, axis=1)  # (H*W, D)


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Sine-cosine encode a set of scalar positions.

    embed_dim: output dimension for each position (must be even)
    pos: positions to be encoded, flattened to size (M,)
    out: (M, D) with the sine terms in the first D/2 columns and the cosine
         terms in the last D/2 columns
    """
    assert embed_dim % 2 == 0

    exponents = np.arange(embed_dim // 2, dtype=np.float64) / (embed_dim / 2.0)
    freqs = 1.0 / 10000**exponents  # (D/2,)

    flat_pos = pos.reshape(-1)  # (M,)
    angles = np.einsum("m,d->md", flat_pos, freqs)  # (M, D/2), outer product

    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
new_size: + print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False + ) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + checkpoint_model["pos_embed"] = new_pos_embed + + +# ------------------------- Basic Positional Encoding ------------------------ # +class SinCosPositionalEncoding(nn.Module): + """Sinusoidal positional encoding as described in the Transformer paper. + + This module adds positional information to the input embeddings using sine and cosine functions + of different frequencies. + + Args: + d_model (int): Dimension of the model + max_len (int): Maximum sequence length + dropout (float): Dropout probability + """ + + def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1): + """ + Args: + d_model: dimension of the model + max_len: maximum sequence length + dropout: dropout rate + """ + super().__init__() + self.dropout = nn.Dropout(p=dropout) + + # Create positional encoding matrix + pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model)) + + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + + # Add batch dimension and store as buffer (won't be trained) + pe = pe.unsqueeze(0) + self.register_buffer("pe", pe) + + def forward(self, x): + """Add positional encoding to the input tensor. 
+ + Args: + x (Tensor): Input tensor of shape [batch_size, seq_len, d_model] + + Returns: + Tensor: Input tensor with positional encoding added + """ + x = x + self.pe[:, x.size(1)] + return self.dropout(x) + + +# ---------------------------------------------------------------------------- # + + +# --------------------- Rotary Positional Encoding (Rope) -------------------- # +class RotaryPositionalEmbedding1D(nn.Module): + """1D Rotary Positional Embedding (RoPE) implementation. + + This module implements the rotary positional embedding method which encodes relative positions + through rotation matrices. + + Args: + model_dim (int): Dimension of the model (must be even) + max_seq_length (int): Maximum sequence length + temperature (float): Temperature parameter for the frequency calculation + """ + + def __init__(self, model_dim: int, max_seq_length: int = 1200, temperature: float = 10000.0): + super(RotaryPositionalEmbedding1D, self).__init__() + + assert model_dim % 2 == 0, "Embedding dimension must be multiple of 2 for 1D positional embedding" + self.model_dim = model_dim + + possible_positions = torch.arange(max_seq_length, dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, model_dim, 2, dtype=torch.float32) * -(torch.log(torch.tensor(temperature)) / model_dim) + ) + pos = possible_positions * div_term + sin = torch.sin(pos) + sin = torch.concat([sin, sin], dim=-1) + self.register_buffer("sin", sin) + cos = torch.cos(pos) + cos = torch.concat([cos, cos], dim=-1) + self.register_buffer("cos", cos) + + def invert_negate(self, x): + """Helper function to invert and negate the second half of the input. + + Args: + x (Tensor): Input tensor + + Returns: + Tensor: Transformed tensor + """ + return torch.cat([-x[..., self.model_dim // 2 :], x[..., : self.model_dim // 2]], dim=-1) + + def forward(self, x, pos): + """Apply rotary positional encoding to the input tensor. 
+ + Args: + x (Tensor): Input tensor of shape [..., model_dim] + pos (Tensor): Position indices of shape [...] + + Returns: + Tensor: Input tensor with rotary positional encoding applied + """ + x = x * self.cos[pos] + self.invert_negate(x) * self.sin[pos] + return x + + +class RotaryPositionalEmbedding2D(nn.Module): + """2D Rotary Positional Embedding (RoPE) implementation. + + This module extends the 1D RoPE to handle 2D positions by applying separate + rotary embeddings to different halves of the input dimension. + + Args: + model_dim (int): Dimension of the model (must be multiple of 4) + max_pos (int): Maximum position value + temperature (float): Temperature parameter for the frequency calculation + """ + + def __init__(self, model_dim: int, max_pos: int = 1200, temperature: float = 10000.0): + super(RotaryPositionalEmbedding2D, self).__init__() + + assert model_dim % 4 == 0, "Embedding dimension must be multiple of 4 for 2D positional embedding" + self.model_dim = model_dim + self.rope1d = RotaryPositionalEmbedding1D(model_dim // 2, max_pos, temperature) + + def forward(self, x, pos): + """Apply 2D rotary positional encoding to the input tensor. 
+ + Args: + x (Tensor): Input tensor of shape [..., model_dim] + pos (Tensor): 2D position indices of shape [..., 2] + + Returns: + Tensor: Input tensor with 2D rotary positional encoding applied + """ + d = self.model_dim // 2 + + x1 = x[..., :d] + x2 = x[..., d:] + + x1 = self.rope1d(x1, pos.select(dim=-1, index=0)) + x2 = self.rope1d(x2, pos.select(dim=-1, index=1)) + + return torch.cat([x1, x2], dim=-1) + + +# ---------------------------------------------------------------------------- # diff --git a/src/utils/kronos_kit/__init__.py b/src/utils/kronos_kit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/utils/kronos_kit/attention.py b/src/utils/kronos_kit/attention.py new file mode 100644 index 0000000..f768bf8 --- /dev/null +++ b/src/utils/kronos_kit/attention.py @@ -0,0 +1,90 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
# References:
#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py

import logging
import os
import warnings

from torch import Tensor
from torch import nn


# xFormers is used when importable unless the XFORMERS_DISABLED env var is set.
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try:
    if XFORMERS_ENABLED:
        from xformers.ops import memory_efficient_attention, unbind

        XFORMERS_AVAILABLE = True
        warnings.warn("xFormers is available (Attention)")
    else:
        warnings.warn("xFormers is disabled (Attention)")
        raise ImportError
except ImportError:
    XFORMERS_AVAILABLE = False
    warnings.warn("xFormers is not available (Attention)")


class Attention(nn.Module):
    """Standard multi-head self-attention over a ``[B, N, C]`` token tensor."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: Tensor, return_attn=False) -> Tensor:
        """Return the attended tokens, or the attention map if ``return_attn``.

        The attention map has shape ``[B, num_heads, N, N]`` and is taken after
        softmax and attention dropout.
        """
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)

        q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
        attn = q @ k.transpose(-2, -1)

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # Skip the value/projection work when only the map is requested.
        if return_attn:
            return attn

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class MemEffAttention(Attention):
    """Attention that uses xFormers' memory-efficient kernel when available."""

    def forward(self, x: Tensor, attn_bias=None, return_attn=False) -> Tensor:
        # The fused kernel never materialises the attention matrix, so a
        # request for it must go through the reference implementation.
        # Previously return_attn was silently ignored on the xFormers path.
        if return_attn:
            if attn_bias is not None:
                raise AssertionError("attention maps cannot be returned together with attn_bias")
            return super().forward(x, return_attn=True)

        if not XFORMERS_AVAILABLE:
            if attn_bias is not None:
                raise AssertionError("xFormers is required for using nested tensors")
            return super().forward(x)

        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)

        q, k, v = unbind(qkv, 2)

        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        x = x.reshape([B, N, C])

        x = self.proj(x)
        x = self.proj_drop(x)
        return x


# ---- file: src/utils/kronos_kit/block.py (new file mode 100644, index 0000000..ebf2ce8) ----
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py

import logging
import os
from typing import Callable, List, Any, Tuple, Dict
import warnings

import torch
from torch import nn, Tensor

from .attention import Attention, MemEffAttention
from .drop_path import DropPath
from .layer_scale import LayerScale
from .mlp import Mlp


logger = logging.getLogger("dinov2")


XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try:
    if XFORMERS_ENABLED:
        from xformers.ops import fmha, scaled_index_add, index_select_cat

        XFORMERS_AVAILABLE = True
        warnings.warn("xFormers is available (Block)")
    else:
        warnings.warn("xFormers is disabled (Block)")
        raise ImportError
except ImportError:
    XFORMERS_AVAILABLE = False

    warnings.warn("xFormers is not available (Block)")


class Block(nn.Module):
    """Pre-norm transformer block: attention and FFN residual branches.

    Each branch is optionally scaled by ``LayerScale`` (when ``init_values`` is
    truthy) and regularised with stochastic depth (when ``drop_path > 0``).
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
    ) -> None:
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor, return_attention=False) -> Tensor:
        """Run the block; with ``return_attention`` return only the attention map."""
        # Inspection path: return this block's attention map for the normed
        # input without applying the residual branches.
        if return_attention:
            return self.attn(self.norm1(x), return_attn=True)

        def attn_residual_func(x: Tensor) -> Tensor:
            return self.ls1(self.attn(self.norm1(x)))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x))
            # The FFN branch has its own DropPath (was drop_path1; both are
            # built with the same rate in __init__).
            x = x + self.drop_path2(ffn_residual_func(x))
        else:
            x = x + attn_residual_func(x)
            x = x + ffn_residual_func(x)
        return x


def drop_add_residual_stochastic_depth(
    x: Tensor,
    residual_func: Callable[[Tensor], Tensor],
    sample_drop_ratio: float = 0.0,
) -> Tensor:
    """Batch-level stochastic depth: compute the residual on a random subset.

    The residual is evaluated only for the kept samples and added back scaled
    by ``b / subset_size``.
    """
    # 1) extract subset using permutation
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    x_subset = x[brange]

    # 2) apply residual_func to get residual
    residual = residual_func(x_subset)

    x_flat = x.flatten(1)
    residual = residual.flatten(1)

    residual_scale_factor = b / sample_subset_size

    # 3) add the residual
    x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
    return x_plus_residual.view_as(x)


def get_branges_scales(x, sample_drop_ratio=0.0):
    """Pick the random kept-sample indices and the matching rescale factor."""
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    residual_scale_factor = b / sample_subset_size
    return brange, residual_scale_factor


def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
    """Add ``residual`` into rows ``brange`` of ``x``.

    With a ``scaling_vector`` the xFormers ``scaled_index_add`` kernel is used.
    """
    if scaling_vector is None:
        x_flat = x.flatten(1)
        residual = residual.flatten(1)
        x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
    else:
        x_plus_residual = scaled_index_add(
            x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
        )
    return x_plus_residual


# Block-diagonal attention masks, keyed by ((batch_size, seq_len), ...).
attn_bias_cache: Dict[Tuple, Any] = {}


def get_attn_bias_and_cat(x_list, branges=None):
    """
    this will perform the index select, cat the tensors, and provide the attn_bias from cache
    """
    batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
    all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
    if all_shapes not in attn_bias_cache:
        seqlens = []
        for b, x in zip(batch_sizes, x_list):
            for _ in range(b):
                seqlens.append(x.shape[1])
        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
        attn_bias._batch_sizes = batch_sizes
        attn_bias_cache[all_shapes] = attn_bias

    if branges is not None:
        cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
    else:
        tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
        cat_tensors = torch.cat(tensors_bs1, dim=1)

    return attn_bias_cache[all_shapes], cat_tensors


def drop_add_residual_stochastic_depth_list(
    x_list: List[Tensor],
    residual_func: Callable[[Tensor, Any], Tensor],
    sample_drop_ratio: float = 0.0,
    scaling_vector=None,
) -> List[Tensor]:
    """List (nested-tensor) variant of ``drop_add_residual_stochastic_depth``."""
    # 1) generate random set of indices for dropping samples in the batch
    branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
    branges = [s[0] for s in branges_scales]
    residual_scale_factors = [s[1] for s in branges_scales]

    # 2) get attention bias and index+concat the tensors
    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)

    # 3) apply residual_func to get residual, and split the result
    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore

    outputs = []
    for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
        outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
    return outputs


class NestedTensorBlock(Block):
    """``Block`` that also accepts a list of tensors (requires xFormers)."""

    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
        """
        x_list contains a list of tensors to nest together and run
        """
        assert isinstance(self.attn, MemEffAttention)

        if self.training and self.sample_drop_ratio > 0.0:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.attn(self.norm1(x), attn_bias=attn_bias)

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.mlp(self.norm2(x))

            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                # Check ls2 itself before reading its gamma (was checking ls1).
                scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None,
            )
            return x_list
        else:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls2(self.mlp(self.norm2(x)))

            attn_bias, x = get_attn_bias_and_cat(x_list)
            x = x + attn_residual_func(x, attn_bias=attn_bias)
            x = x + ffn_residual_func(x)
            return attn_bias.split(x)

    def forward(self, x_or_x_list, return_attention=False):
        """Dispatch on input type: a tensor runs ``Block.forward``, a list the nested path."""
        if isinstance(x_or_x_list, Tensor):
            return super().forward(x_or_x_list, return_attention)
        elif isinstance(x_or_x_list, list):
            if not XFORMERS_AVAILABLE:
                raise AssertionError("xFormers is required for using nested tensors")
            return self.forward_nested(x_or_x_list)
        else:
            raise AssertionError


# ---- file: src/utils/kronos_kit/dino_head.py (new file mode 100644, index 0000000..0ace8ff) ----
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
from torch.nn.init import trunc_normal_
from torch.nn.utils import weight_norm


class DINOHead(nn.Module):
    """DINO projection head: MLP, L2-normalisation, weight-normed linear layer.

    Args:
        in_dim: Input feature size.
        out_dim: Output size of the final weight-normalised layer.
        use_bn: Insert ``BatchNorm1d`` after each hidden linear layer.
        nlayers: Number of linear layers in the MLP (values < 1 become 1).
        hidden_dim: Width of the hidden layers (needed when ``nlayers > 1``).
        bottleneck_dim: Output size of the MLP.
        mlp_bias: Whether the MLP linear layers carry a bias.
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        use_bn=False,
        nlayers=3,
        hidden_dim=2048,
        bottleneck_dim=256,
        mlp_bias=True,
    ):
        super().__init__()
        nlayers = max(nlayers, 1)
        self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
        # Initialise the MLP before last_layer exists, so the custom init does
        # not touch the weight-normalised layer.
        self.apply(self._init_weights)
        self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
        self.last_layer.weight_g.data.fill_(1)

    def _init_weights(self, m):
        """Truncated-normal weights and zero bias for every ``nn.Linear``."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Project ``x``, L2-normalise along the last dim, apply the last layer."""
        x = self.mlp(x)
        # Larger eps in half precision (the original's dtype-dependent choice).
        eps = 1e-6 if x.dtype == torch.float16 else 1e-12
        x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
        x = self.last_layer(x)
        return x


def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
    """Build the head's MLP: a single ``Linear`` or a ``Sequential`` stack.

    Raises:
        ValueError: If ``nlayers > 1`` and ``hidden_dim`` is ``None``.
    """
    if nlayers == 1:
        return nn.Linear(in_dim, bottleneck_dim, bias=bias)

    if hidden_dim is None:
        # Previously surfaced as an opaque TypeError inside nn.Linear.
        raise ValueError("hidden_dim must be provided when nlayers > 1")

    layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
    if use_bn:
        layers.append(nn.BatchNorm1d(hidden_dim))
    layers.append(nn.GELU())
    for _ in range(nlayers - 2):
        layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
        if use_bn:
            layers.append(nn.BatchNorm1d(hidden_dim))
        layers.append(nn.GELU())
    layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
    return nn.Sequential(*layers)


# ---- file: src/utils/kronos_kit/drop_path.py (new file mode 100644, index 0000000..1d640e0) ----
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
# References:
#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py


from torch import nn


def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """Zero whole samples with probability ``drop_prob``, rescaling the rest.

    Returns ``x`` unchanged when not training or when ``drop_prob`` is 0 or
    ``None`` (``None`` is ``DropPath``'s default and used to raise TypeError).
    """
    if not drop_prob or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        random_tensor.div_(keep_prob)
    return x * random_tensor


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


# ---- file: src/utils/kronos_kit/image_funcs.py (new file mode 100644, index 0000000..a2b3687) ----
import os
import numpy as np
import pandas as pd
import dask
import dask.array as da
from dask import delayed
from dask.diagnostics import ProgressBar
import h5py
import re
import torch
from tifffile import TiffFile
import torch.nn.functional as F


def process(raw_patch, mask_patch, marker_df):
    """Normalise the channels listed in ``marker_df`` and mask them to the cell.

    Args:
        raw_patch: ``[C, H, W]`` array with an integer dtype (``np.iinfo`` is
            applied to it).
        mask_patch: ``[H, W]`` binary cell mask.
        marker_df: Rows with ``channel_id``, ``marker_id``, ``marker_mean`` and
            ``marker_std``.

    Returns:
        ``(patch_markers, marker_ids)``: the stacked normalised channels
        multiplied by the mask, and a tensor of marker ids.
    """
    # Scale by the dtype's full range, e.g. 65535 for uint16.
    marker_max_values = np.iinfo(raw_patch.dtype).max

    patch_markers = []
    marker_ids = []
    for _, r in marker_df.iterrows():
        channel_index = r['channel_id']
        marker = raw_patch[channel_index, :, :] / marker_max_values
        marker = (marker - r['marker_mean']) / r['marker_std']

        patch_markers.append(torch.tensor(marker))
        marker_ids.append(np.uint16(r['marker_id']))

    patch_markers = torch.stack(patch_markers, dim=0)
    marker_ids = torch.tensor(marker_ids)
    cell_mask = np.uint8(mask_patch)
    patch_markers = patch_markers * cell_mask

    return patch_markers, marker_ids


def patch_generator(raw, mask, cell_cutout, df):
    """Yield one ``cell_cutout`` x ``cell_cutout`` crop per row of ``df``.

    Crops are centred on the rounded centroid and zero-padded where the window
    leaves the image. Cell ids come from ``df["cell_ID"]`` when present,
    otherwise from the mask label under the centroid.

    Yields:
        ``(raw_patch, mask_patch, i)``: the ``[C, k, k]`` image crop, the
        ``[k, k]`` uint8 mask of that cell only, and the row position ``i``.

    Raises:
        ValueError: If a centroid falls on background (label 0) while ids are
            being read from the mask.
    """
    xs = df["Centroid.X.px"].round().astype(int).to_numpy()
    ys = df["Centroid.Y.px"].round().astype(int).to_numpy()
    if "cell_ID" in df.columns:
        cell_ids = df["cell_ID"].astype(int).to_numpy()
    else:
        # int64 rather than uint16: masks can hold more than 65535 labels.
        cell_ids = np.empty(len(xs), dtype=np.int64)
        for i in range(len(xs)):
            cell_id = int(mask[ys[i], xs[i]])
            if cell_id == 0:
                raise ValueError('Warning! A centroid points outside of cell')
            cell_ids[i] = cell_id

    half = cell_cutout // 2
    height, width = raw.shape[1], raw.shape[2]

    for i in range(len(xs)):
        # Arrays are indexed [row, column] = [y, x].
        row, col = ys[i], xs[i]
        cell_id = cell_ids[i]

        # Ideal window. Using row0 + cell_cutout (not row + half) also gives
        # the right size for odd cutouts, which used to trip the assert below.
        row0, col0 = row - half, col - half
        row1, col1 = row0 + cell_cutout, col0 + cell_cutout

        # Window clipped to the image.
        r_lo, r_hi = max(row0, 0), min(row1, height)
        c_lo, c_hi = max(col0, 0), min(col1, width)

        raw_patch = raw[:, r_lo:r_hi, c_lo:c_hi]
        mask_patch = mask[r_lo:r_hi, c_lo:c_hi]

        pad_rows = (r_lo - row0, row1 - r_hi)
        pad_cols = (c_lo - col0, col1 - c_hi)
        if any(pad_rows) or any(pad_cols):
            raw_patch = np.pad(raw_patch, ((0, 0), pad_rows, pad_cols), mode='constant', constant_values=0)
            mask_patch = np.pad(mask_patch, (pad_rows, pad_cols), mode='constant', constant_values=0)

        assert raw_patch.shape[1] == cell_cutout and raw_patch.shape[2] == cell_cutout, "Patch size mismatch after padding"
        assert mask_patch.shape[0] == cell_cutout and mask_patch.shape[1] == cell_cutout, "Mask size mismatch after padding"

        mask_patch = (mask_patch == cell_id).astype(np.uint8)

        yield raw_patch, mask_patch, i


@delayed
def cell_seg(img_path,
             mask_path,
             df,
             cell_cutout,
             marker_df,
             channel_names,
             channel_mask):
    """Cut, normalise and save all cell patches of one ROI (dask-delayed).

    Writes ``<img_path minus '.ome.tif'>_cells.npy`` with shape
    ``[n_cells, n_markers, 64, 64]``. ``channel_names`` and ``channel_mask``
    are accepted for interface compatibility but not used here.
    """
    out_size = 64  # fixed spatial size of the saved patches

    # Context managers close the TIFF handles (previously leaked).
    with TiffFile(img_path) as tif:
        raw = tif.asarray()
    with TiffFile(mask_path) as tif:
        mask = tif.asarray()

    cell_results = torch.empty((
        df.shape[0],
        marker_df.shape[0],
        out_size,
        out_size
    ), dtype=torch.float32)

    for raw_patch, mask_patch, run in patch_generator(raw, mask, cell_cutout, df):
        patch_markers, _ = process(raw_patch, mask_patch, marker_df)
        patch_markers = patch_markers.unsqueeze(0)  # (C,H,W) -> (1,C,H,W)
        if cell_cutout != out_size:
            patch_markers = F.interpolate(
                patch_markers,
                size=(out_size, out_size),
                mode="bilinear",
                align_corners=False
            )
        cell_results[run] = patch_markers.squeeze(0)

    np.save(img_path.split('.ome.tif')[0] + '_cells.npy', cell_results.numpy())


def extract_idx(path):
    """Return the trailing integer index of an ROI file name.

    ``name_12.ome.tif`` and ``name_12_mask.tif`` both give ``12``.
    """
    stem = path.split(os.sep)[-1]
    if '.ome.tif' in stem:
        stem = stem.split('.ome.tif')[0]
    elif '_mask.tif' in stem:
        stem = stem.split('_mask.tif')[0]
    return int(stem.rsplit('_', 1)[1])


def image_preprocess(path,
                     channel_names,
                     cell_cutout=64,
                     batch_size=1,
                     ids_path=""):
    """Extract per-cell patches for every ROI under ``path``.

    Expects ``masks``, ``train`` and ``test`` sub-directories under ``path``.
    Cell centroids come from ``data/raw/<ids_path>`` or, when ``ids_path`` is
    empty, from the first CSV found in ``path``. ROIs are processed in
    ``batch_size``-sized groups of dask tasks.

    Raises:
        FileNotFoundError: If ``ids_path`` is empty and ``path`` has no CSV.
    """
    from src.utils.kronos_kit.marker_metadata import check_metadata

    channel_mask = check_metadata(path, channel_names)
    channel_names = [p for p, m in zip(channel_names, channel_mask) if m]

    mask_dir = os.path.join(os.getcwd(), path, 'masks')
    train_dir = os.path.join(os.getcwd(), path, 'train')
    test_dir = os.path.join(os.getcwd(), path, 'test')

    marker_df = pd.read_csv(os.path.join(os.getcwd(), 'data', 'marker_info_with_metadata.csv'))

    if ids_path != "":
        df = pd.read_csv(os.path.join(os.getcwd(), "data", "raw", ids_path))
    else:
        csv_files = [os.path.abspath(os.path.join(path, p))
                     for p in os.listdir(path)
                     if p.lower().endswith('csv')]
        if not csv_files:
            raise FileNotFoundError(f"No CSV file with cell centroids found in {path}")
        df = pd.read_csv(csv_files[0])

    # Base names with the trailing "_<roi index>" removed; sorted so the
    # processing order is reproducible.
    img_names = sorted({
        re.sub(r'_\d+$', '', p.split(os.sep)[-1].split('.ome.tif')[0])
        for p in df['Image'].tolist()
    })

    print(">>> Detected image names:")
    for name in img_names:
        print(name)

    total_rois = len(os.listdir(mask_dir))
    current_total = 0
    skip_suffixes = ('.npy', '.pt')  # previously generated outputs
    for img_name in img_names:
        img_paths = sorted(
            [os.path.join(train_dir, p)
             for p in os.listdir(train_dir)
             if p.startswith(img_name) and not p.endswith(skip_suffixes)]
            + [os.path.join(test_dir, p)
               for p in os.listdir(test_dir)
               if p.startswith(img_name) and not p.endswith(skip_suffixes)],
            key=extract_idx)
        mask_paths = sorted([os.path.join(mask_dir, p)
                             for p in os.listdir(mask_dir)
                             if p.startswith(img_name)],
                            key=extract_idx)

        if len(img_paths) == 0:
            print(f">>> Warning! No matching images for {img_name}")
            continue
        if len(mask_paths) == 0:
            print(f">>> Warning! No matching masks for {img_name}")
            continue
        if len(mask_paths) != len(img_paths):
            # Images and masks are paired by position; a count mismatch would
            # mis-pair them or raise IndexError.
            print(f">>> Warning! {len(img_paths)} images but {len(mask_paths)} masks for {img_name}, skipping")
            continue

        print(f'>> Processing {img_name}, {len(img_paths)} ROIs in batches of {batch_size}')
        for start in range(0, len(img_paths), batch_size):
            end = min(start + batch_size, len(img_paths))
            tasks = [cell_seg(img_paths[p],
                              mask_paths[p],
                              df[df["Image"] == img_paths[p].split(os.sep)[-1]],
                              cell_cutout,
                              marker_df,
                              channel_names,
                              channel_mask)
                     for p in range(start, end)]
            with ProgressBar():
                dask.compute(*tasks)
        # NOTE(review): counted once per image name; the nesting of these two
        # lines is not recoverable from the collapsed patch. Confirm.
        current_total += len(img_paths)
        print(f'{current_total}/{total_rois} ROIs processed')


# ---- file: src/utils/kronos_kit/layer_scale.py (new file mode 100644, index 0000000..51df0d7) ----
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110

from typing import Union

import torch
from torch import Tensor
from torch import nn


class LayerScale(nn.Module):
    """Learnable per-channel scaling, ``x * gamma`` with ``gamma`` of shape ``[dim]``."""

    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        # The in-place variant modifies and returns the caller's tensor.
        return x.mul_(self.gamma) if self.inplace else x * self.gamma


# ---- file: src/utils/kronos_kit/marker_metadata.py (new file mode 100644, index 0000000..b7030a1) ----
import numpy as np
import pandas as pd
import os
from difflib import SequenceMatcher
import pickle


def check_metadata(path, channel_names):
    """Match channel names against the KRONOS marker table and export the result.

    Reads ``src/utils/kronos_kit/marker_metadata.csv`` and writes
    ``data/marker_info_with_metadata.csv`` (both relative to the working
    directory). Channels without a reference marker are reported and masked.
    ``path`` is accepted for interface compatibility but not used.

    Returns:
        ``list[bool]`` with one entry per channel, ``True`` where the channel
        was matched.

    Raises:
        ValueError: If ``channel_names`` is empty.
    """
    if channel_names is None or len(channel_names) == 0:
        raise ValueError("Empty channel names argument")
    kronos_metadata = pd.read_csv(
        os.path.join(os.getcwd(), 'src', 'utils', 'kronos_kit', 'marker_metadata.csv'))

    marker_info = pd.DataFrame({
        "channel_id": range(0, len(channel_names)),
        "marker_name": channel_names
    })

    obj = MarkerMetadata(marker_info, kronos_metadata)
    obj.get_marker_metadata()

    if len(obj.missing_marker_dict) != 0:

        print(f"There are {len(obj.missing_marker_dict)} markers that do not match with the markers in the pretrained dataset.")
        print('Perhaps there is a name mismatch?')
        print("Below are the top 5 marker name similarity suggestions for each missing marker:")
        print(obj.missing_marker_df)

        print('THE MISSING MARKER WILL BE MASKED!')
        missing = list(obj.missing_marker_dict.keys())
        marker_info = marker_info[~marker_info['marker_name'].isin(missing)].reset_index(drop=True)

    # Always a boolean mask. An integer id list would make channel 0 falsy for
    # callers that filter with `if m`.
    kept_ids = set(marker_info['channel_id'].tolist())
    channel_mask = [i in kept_ids for i in range(len(channel_names))]
    output_csv_path = os.path.join(os.getcwd(), 'data', 'marker_info_with_metadata.csv')  # TODO change output_csv_path to another directory
    obj.export_marker_metadata(output_csv_path)
    return channel_mask


class MarkerMetadata:
    """
    Match the markers of a dataset against a reference marker-metadata table.

    Attributes:
        marker_info (pd.DataFrame): One row per channel with ``channel_id`` and
            ``marker_name``; ``marker_id``, ``marker_mean`` and ``marker_std``
            columns are added by ``get_marker_metadata``.
        marker_metatdata (pd.DataFrame): Reference table with ``marker_name``,
            ``marker_id``, ``marker_mean`` and ``marker_std``. Names are looked
            up upper-cased.
        top_suggestions (int): Number of name suggestions per unmatched marker.
        missing_marker_df (pd.DataFrame): Unmatched markers and their suggestions.
        missing_marker_dict (dict): Unmatched marker -> replacement name (``''``
            until the user fills it in).
    Methods:
        __init__(marker_info_csv, marker_metadata_csv, top_suggestions=5):
            Stores the two already-loaded DataFrames (not file paths).
        get_marker_metadata():
            Fills metadata for matched markers and builds suggestions for the rest.
        get_marker_metadata_with_mapping():
            Applies the replacements entered in ``missing_marker_dict``.
        set_marker_metadata(marker_metadata_dict):
            Manually sets metadata for specific markers.
        export_marker_metadata(output_csv_path):
            Writes the matched rows of ``marker_info`` to a CSV file.
    """

    def __init__(self, marker_info_csv, marker_metadata_csv, top_suggestions=5):
        self.marker_info = marker_info_csv
        self.marker_metatdata = marker_metadata_csv
        self.top_suggestions = top_suggestions

    def get_marker_metadata(self):
        """Match markers by upper-cased name and collect suggestions for misses.

        Returns:
            ``(marker_metatdata, marker_info, missing_marker_df, missing_marker_dict)``.
        """
        # Index by name on a copy: the caller's DataFrame is left untouched and
        # a repeated call does not fail on the already-consumed column.
        if "marker_name" in self.marker_metatdata.columns:
            self.marker_metatdata = self.marker_metatdata.set_index("marker_name")

        n_rows = self.marker_info.shape[0]
        self.marker_info["marker_id"] = [0] * n_rows
        self.marker_info["marker_mean"] = [0.0] * n_rows
        self.marker_info["marker_std"] = [1.0] * n_rows

        missing_markers = []
        unmatched_markers = self.marker_metatdata.index.tolist()

        for i, row in self.marker_info.iterrows():
            marker_name = row["marker_name"].upper()
            if marker_name not in self.marker_metatdata.index:
                missing_markers.append(row["marker_name"])
                continue
            # A marker listed on two channels is only removed once.
            if marker_name in unmatched_markers:
                unmatched_markers.remove(marker_name)
            for column in ("marker_id", "marker_mean", "marker_std"):
                self.marker_info.loc[i, column] = self.marker_metatdata.loc[marker_name, column]

        missing_markers.sort()
        unmatched_markers.sort()

        missing_dict = {"Missing Marker": []}
        for i in range(self.top_suggestions):
            missing_dict[f"Suggestion {i+1}"] = []

        for missing_marker in missing_markers:
            missing_dict["Missing Marker"].append(missing_marker)
            similarity_list = np.array([
                SequenceMatcher(None, missing_marker.upper(), marker_name).ratio()
                for marker_name in unmatched_markers
            ])
            # kind="stable" works on every NumPy version; the `stable=` keyword
            # only exists from NumPy 2.0.
            sorted_index = np.argsort(similarity_list, kind="stable")[::-1]
            for i in range(self.top_suggestions):
                # Pad with '' when fewer candidates than suggestions remain.
                suggestion = unmatched_markers[sorted_index[i]] if i < len(sorted_index) else ''
                missing_dict[f"Suggestion {i+1}"].append(suggestion)
        self.missing_marker_df = pd.DataFrame(missing_dict)
        self.missing_marker_df.set_index("Missing Marker", inplace=True)

        self.missing_marker_dict = {missing_marker: '' for missing_marker in missing_markers}
        return self.marker_metatdata, self.marker_info, self.missing_marker_df, self.missing_marker_dict

    def get_marker_metadata_with_mapping(self):
        """Apply user-entered replacements from ``missing_marker_dict``."""
        matched_markers = []
        for key, value in self.missing_marker_dict.items():
            if value == '':
                continue

            if value in self.marker_metatdata.index:
                selected = self.marker_info["marker_name"] == key
                for column in ("marker_id", "marker_mean", "marker_std"):
                    self.marker_info.loc[selected, column] = self.marker_metatdata.loc[value, column]
                matched_markers.append(key)
            else:
                print(f"Marker {key} not found in metadata")

        # Removed after the loop so the dict is not mutated while iterating.
        for marker_name in matched_markers:
            if marker_name in self.missing_marker_df.index:
                self.missing_marker_df.drop(marker_name, inplace=True)
            del self.missing_marker_dict[marker_name]

    def set_marker_metadata(self, marker_metadata_dict):
        """Set metadata manually: ``{name: {marker_id, marker_mean, marker_std}}``."""
        for marker_name, values in marker_metadata_dict.items():
            selected = self.marker_info["marker_name"] == marker_name
            for column in ("marker_id", "marker_mean", "marker_std"):
                self.marker_info.loc[selected, column] = values[column]
            if marker_name in self.missing_marker_df.index:
                self.missing_marker_df.drop(marker_name, inplace=True)
            if marker_name in self.missing_marker_dict:
                del self.missing_marker_dict[marker_name]

    def export_marker_metadata(self, output_csv_path):
        """Write rows with a non-zero ``marker_id`` to ``output_csv_path``.

        Unmatched rows keep the default id 0 and are dropped from
        ``self.marker_info`` as well.
        """
        self.marker_info = self.marker_info[self.marker_info["marker_id"] != 0]
        self.marker_info.reset_index(drop=True, inplace=True)
        self.marker_info.to_csv(output_csv_path, index=False)
        print(f"Exported marker metadata to {output_csv_path}")


# ---- file: src/utils/kronos_kit/mlp.py (new file mode 100644, index 0000000..bbf9432) ----
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py


from typing import Callable, Optional

from torch import Tensor, nn


class Mlp(nn.Module):
    """Two-layer feed-forward network: Linear, activation, dropout, Linear, dropout.

    ``hidden_features`` and ``out_features`` default to ``in_features``.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


# ---- file: src/utils/kronos_kit/patch_embed.py (new file mode 100644, index 0000000..f4cdfb8) ----
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
# References:
#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py

from typing import Callable, Optional, Tuple, Union

import torch
from torch import Tensor
import torch.nn as nn


def make_2tuple(x):
    """Return ``x`` as a pair: an int ``n`` becomes ``(n, n)``, a 2-tuple is returned unchanged."""
    if not isinstance(x, tuple):
        assert isinstance(x, int)
        return (x, x)
    assert len(x) == 2
    return x


class PatchEmbed(nn.Module):
    """
    Per-channel 2D image to patch embedding: (B,C,H,W) -> (B,C*N,D)

    The same convolution is applied to every input channel separately (each
    channel is fed as a one-channel image) and the resulting token sequences
    are concatenated along the token axis.

    Args:
        img_size: Image size.
        patch_size: Patch token size.
        stride_size: Stride of the projection convolution.
        in_chans: Number of input channels of the projection convolution.
        embed_dim: Number of linear projection output channels.
        norm_layer: Normalization layer.
        flatten_embedding: Must be True; the non-flattened output is guarded
            by an assertion that always fails.
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        stride_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()

        self.img_size = make_2tuple(img_size)
        self.patch_size = make_2tuple(patch_size)
        stride_hw = make_2tuple(stride_size)

        # The grid is derived from image size and patch size only.
        grid_h = self.img_size[0] // self.patch_size[0]
        grid_w = self.img_size[1] // self.patch_size[1]
        self.patches_resolution = (grid_h, grid_w)
        self.num_patches = grid_h * grid_w

        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.flatten_embedding = flatten_embedding

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=self.patch_size, stride=stride_hw)
        self.norm = nn.Identity() if not norm_layer else norm_layer(embed_dim)

    def forward(self, x: Tensor) -> Tensor:
        _, num_channels, height, width = x.shape
        patch_h, patch_w = self.patch_size

        assert height % patch_h == 0, f"Input image height {height} is not a multiple of patch height {patch_h}"
        assert width % patch_w == 0, f"Input image width {width} is not a multiple of patch width: {patch_w}"

        # One projection per channel; each channel is passed as (B, 1, H, W).
        # NOTE(review): this looks like it expects in_chans == 1 - confirm against the caller.
        feature_maps = [self.proj(x[:, c, :, :].unsqueeze(1)) for c in range(num_channels)]
        tokens = torch.cat([fmap.flatten(2).transpose(1, 2) for fmap in feature_maps], dim=1)
        tokens = self.norm(tokens)
        if self.flatten_embedding:
            return tokens

        assert self.flatten_embedding, "flatten_embedding=False not supported. Check the implementation of PatchEmbed."
        out_h, out_w = feature_maps[-1].size(2), feature_maps[-1].size(3)
        per_channel = out_h * out_w
        # B H W C for each marker
        return [
            tokens[:, c * per_channel:(c + 1) * per_channel, :].reshape(-1, out_h, out_w, self.embed_dim)
            for c in range(num_channels)
        ]

    def flops(self) -> float:
        grid_h, grid_w = self.patches_resolution
        patch_area = self.patch_size[0] * self.patch_size[1]
        total = grid_h * grid_w * self.embed_dim * self.in_chans * patch_area
        if self.norm is not None:
            total += grid_h * grid_w * self.embed_dim
        return total
import os
from typing import Callable, Optional
import warnings

from torch import Tensor, nn
import torch.nn.functional as F


class SwiGLUFFN(nn.Module):
    """Feed-forward block computing ``w3(silu(x1) * x2)`` where ``x1, x2`` are the two halves of ``w12(x)``.

    ``act_layer`` and ``drop`` are accepted but not used by this implementation.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        hidden_dim = hidden_features or in_features
        output_dim = out_features or in_features
        # A single linear layer produces both halves; they are split in forward().
        self.w12 = nn.Linear(in_features, 2 * hidden_dim, bias=bias)
        self.w3 = nn.Linear(hidden_dim, output_dim, bias=bias)

    def forward(self, x: Tensor) -> Tensor:
        first_half, second_half = self.w12(x).chunk(2, dim=-1)
        return self.w3(F.silu(first_half) * second_half)


# xFormers is used unless the XFORMERS_DISABLED environment variable is set
# (to any value) or the import fails; in both cases SwiGLU falls back to the
# pure-PyTorch SwiGLUFFN above.
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try:
    if not XFORMERS_ENABLED:
        warnings.warn("xFormers is disabled (SwiGLU)")
        raise ImportError
    from xformers.ops import SwiGLU

    XFORMERS_AVAILABLE = True
    warnings.warn("xFormers is available (SwiGLU)")
except ImportError:
    SwiGLU = SwiGLUFFN
    XFORMERS_AVAILABLE = False

    warnings.warn("xFormers is not available (SwiGLU)")


class SwiGLUFFNFused(SwiGLU):
    """``SwiGLU`` whose hidden width is scaled by 2/3 and rounded up to a multiple of 8.

    ``act_layer`` and ``drop`` are accepted but not forwarded to the base class.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        output_dim = out_features or in_features
        requested_hidden = hidden_features or in_features
        # 2/3 scaling, then round up to the next multiple of 8.
        fused_hidden = (int(requested_hidden * 2 / 3) + 7) // 8 * 8
        super().__init__(
            in_features=in_features,
            hidden_features=fused_hidden,
            out_features=output_dim,
            bias=bias,
        )
insertions(+), 1 deletion(-) diff --git a/celldownload.py b/celldownload.py index aad2b96..92ce614 100644 --- a/celldownload.py +++ b/celldownload.py @@ -16,7 +16,12 @@ GenePT_gene_protein_embedding_model_3_text.pickle and store it as: GenePT_embedding.pkl -in src.utils.eva_kit +in src/utils/eva_kit + +KRONOS requires marker_metadata.csv from huggingface + https://huggingface.co/MahmoodLab/KRONOS +Download and place in: +src/utils/kronos_kit """ @@ -35,6 +40,8 @@ def parse_args(): def main(**args): model_path = os.path.join(os.getcwd(),"out","models") + if not os.path.exists(): + os.makedirs(model_path) if args["model"] == "deepcell": from src.utils.download_utils.dct_download import dct_download From bf6ec26834738dd4a483a4bb5549300399c7cb33 Mon Sep 17 00:00:00 2001 From: Jonathan Olsson Date: Tue, 19 May 2026 11:27:57 +0200 Subject: [PATCH 3/3] Kronos, Deepcell, Eva implementation --- cellcontrast.py | 206 +++++++++++++++++++++++++++++------------------- 1 file changed, 125 insertions(+), 81 deletions(-) diff --git a/cellcontrast.py b/cellcontrast.py index 3bfed7f..9407267 100644 --- a/cellcontrast.py +++ b/cellcontrast.py @@ -1,81 +1,125 @@ -import argparse - -def parse_args(): - parser = argparse.ArgumentParser(description="Arguments for image model") - - # Arguments for Image preprocessing - parser.add_argument("--preprocess_dir", type=str, default="data/raw/p2106", - help="Directory in which .tiff files are for preprocessing") - parser.add_argument("--preprocess_channels", type=str, default="", - help="Indices of channels to preprocess, seperated by , and empty if all channels") - parser.add_argument("--calc_mean_std", action="store_true", default=False, - help="Wether or not to calculate mean and std of cell cut outs") - parser.add_argument("--cell_cutout", type=int, default=20, - help="Size*Size cutout of cell, centered on Centroid Cell position") - parser.add_argument("--preprocess_workers", type=int, default=1, - help="Number of Workers to use for cell 
cutout") - parser.add_argument("--image_preprocess", action="store_true", default=False, - help="Wether or not to preprocess images via ZScore normalisation") - - # General Model Arguments - parser.add_argument("--deterministic", action="store_true", default=False, - help="Wether or not to run NNs deterministicly") - parser.add_argument("--seed", type=int, default=42, - help="Seed for random computations") - parser.add_argument("--root_dir", type=str, default="data/", - help="Where to find the raw/ and processed/ dirs") - parser.add_argument("--raw_subset_dir", type=str, default="TMA1_preprocessed", - help="How the subdir in raw/ and processed/ is called") - parser.add_argument("--batch_size", type=int, default=256, - help="Number of elements per Batch") - parser.add_argument("--epochs", type=int, default=100, - help="Number of epochs for which to train") - parser.add_argument("--num_workers", type=int, default=1, - help="Number of worker processes to be used(loading data etc)") - parser.add_argument("--lr", type=float, default=0.1, - help="Learning rate of model") - parser.add_argument("--weight_decay", type=float, default=5e-6, - help="Weight decay of optimizer") - parser.add_argument("--early_stopping", type=int, default=100, - help="Number of epochs after which to stop model run without improvement to val loss") - parser.add_argument("--output_name", type=str, default="out/models/image_contrast.pt", - help="Path/name of moel for saving") - - # Arguments for image model - parser.add_argument("--warmup_epochs", type=int, default=10, - help="Number of Epochs in which learning rate gets increased") - parser.add_argument("--embed", type=int, default=256, - help="Linear net size used to embed data") - parser.add_argument("--contrast", type=int, default=124, - help="Linear net size on which to calculate the contrast loss") - parser.add_argument("--crop_factor", type=float, default=0.5, - help="Cell Image crop factor for Image augmentation") - 
import argparse
import importlib
import warnings

# Foundation models with a dedicated preprocessing module. Every module
# exposes an ``image_preprocess`` function.
_FOUNDATION_PREPROCESS_MODULES = {
    "deepcell": "src.utils.deepcell_kit.image_funcs",
    "kronos": "src.utils.kronos_kit.image_funcs",
    "eva": "src.utils.eva_kit.image_funcs",
}


def parse_args(argv=None):
    """Parse the command-line options of the image-model pipeline.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]`` as before.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="Arguments for image model")

    # Arguments for Image preprocessing
    parser.add_argument("--preprocess_dir", type=str, default="data/raw/p2106",
                        help="Directory in which .tiff files are for preprocessing")
    parser.add_argument("--preprocess_channels", type=str, default="",
                        help="Indices of channels to preprocess, seperated by , and empty if all channels")
    parser.add_argument("--calc_mean_std", action="store_true", default=False,
                        help="Wether or not to calculate mean and std of cell cut outs")
    parser.add_argument("--cell_cutout", type=int, default=20,
                        help="Size*Size cutout of cell, centered on Centroid Cell position")
    # The previous help text described `num_workers` (data-loading threads);
    # this value is only passed as `num_processes` to the default cutout step.
    parser.add_argument("--preprocess_workers", type=int, default=1,
                        help="Number of worker processes to use for the cell cutout")
    parser.add_argument("--image_preprocess", action="store_true", default=False,
                        help="Wether or not to preprocess images via ZScore normalisation")

    # General Model Arguments
    parser.add_argument("--deterministic", action="store_true", default=False,
                        help="Wether or not to run NNs deterministicly")
    parser.add_argument("--seed", type=int, default=42,
                        help="Seed for random computations")
    parser.add_argument("--root_dir", type=str, default="data/",
                        help="Where to find the raw/ and processed/ dirs")
    parser.add_argument("--raw_subset_dir", type=str, default="TMA1_preprocessed",
                        help="How the subdir in raw/ and processed/ is called")
    parser.add_argument("--batch_size", type=int, default=256,
                        help="Number of elements per Batch")
    parser.add_argument("--epochs", type=int, default=100,
                        help="Number of epochs for which to train")
    parser.add_argument("--num_workers", type=int, default=1,
                        help="Number of worker processes to be used(loading data etc)")
    parser.add_argument("--lr", type=float, default=0.1,
                        help="Learning rate of model")
    parser.add_argument("--weight_decay", type=float, default=5e-6,
                        help="Weight decay of optimizer")
    parser.add_argument("--early_stopping", type=int, default=100,
                        help="Number of epochs after which to stop model run without improvement to val loss")
    parser.add_argument("--output_name", type=str, default="out/models/image_contrast.pt",
                        help="Path/name of moel for saving")

    # Foundation-model options
    parser.add_argument("--foundation_model", type=str, default="",
                        help="Which foundation model to use")
    # NOTE(review): with nargs='+' a supplied value is a list while the default
    # is the empty string; the default is left untouched because code outside
    # this file may compare against "".
    parser.add_argument("--channel_names", nargs='+', type=str, default="",
                        help="For '--foundation_model deepcell,': "
                             "Channel names, example '--channel_names CD3 CD8 CD20'")
    parser.add_argument("--mpp", type=float, default=0.399,
                        help="For '--foundation_model deepcell,': "
                             "Microns per pixel, passed as float.")
    parser.add_argument("--cell_ids", type=str, default="",
                        help="For --foundation_model deepcell/kronos, "
                             "cell ids can be extracted from a separate CSV")

    # Arguments for image model
    parser.add_argument("--warmup_epochs", type=int, default=10,
                        help="Number of Epochs in which learning rate gets increased")
    parser.add_argument("--embed", type=int, default=256,
                        help="Linear net size used to embed data")
    parser.add_argument("--contrast", type=int, default=124,
                        help="Linear net size on which to calculate the contrast loss")
    parser.add_argument("--crop_factor", type=float, default=0.5,
                        help="Cell Image crop factor for Image augmentation")
    parser.add_argument("--resnet", type=str, default="18",
                        help="What ResNet model to choose, on of 18, 34, 50 and 101")
    parser.add_argument("--n_clusters_image", type=int, default=1,
                        help="Number of Clusters to use for KMeans when only use when >= 1")
    parser.add_argument("--train_image_model", action="store_true", default=False,
                        help="Wether or not to train the Image model")
    parser.add_argument("--embed_image_data", action="store_true", default=False,
                        help="Wether or not to embed data with a given Image model")
    return parser.parse_args(argv)


def _foundation_preprocess_kwargs(model, args):
    """Build the keyword arguments for a foundation model's ``image_preprocess``."""
    kwargs = {
        "path": args['preprocess_dir'],
        "channel_names": args['channel_names'],
        "cell_cutout": args['cell_cutout'],
        "batch_size": args['batch_size'],
        "ids_path": args['cell_ids'],
    }
    if model == "deepcell":
        # Only the DeepCell preprocessing takes the pixel size.
        kwargs["mpp"] = args['mpp']
    return kwargs


def main(**args):
    """Run the pipeline stages selected by the boolean flags in ``args``.

    Stages run in this order when their flag is truthy: ``image_preprocess``,
    ``train_image_model``, ``embed_image_data``. Project modules are imported
    only for the stages that actually run.

    Args:
        **args: The parsed command-line options as a dict (see ``parse_args``).
    """
    if args['image_preprocess']:
        # Tolerate a missing key (callers predating the option) as well as
        # case and surrounding whitespace in the model name.
        model = str(args.get('foundation_model') or '').strip().lower()
        module_name = _FOUNDATION_PREPROCESS_MODULES.get(model)
        if module_name is not None:
            preprocess = importlib.import_module(module_name).image_preprocess
            preprocess(**_foundation_preprocess_kwargs(model, args))
        else:
            if model:
                # Keep the historical fallback, but make it visible.
                warnings.warn(
                    f"Unknown foundation model {model!r}; using the default image preprocessing"
                )
            from src.utils.image_preprocess import image_preprocess as ImagePreprocess
            ImagePreprocess(path=args['preprocess_dir'],
                            img_channels=args['preprocess_channels'],
                            do_mean_std=args['calc_mean_std'],
                            cell_cutout=args['cell_cutout'],
                            num_processes=args['preprocess_workers'])
    if args['train_image_model']:
        from src.run.CellContrastTrain import train as ImageTrain
        ImageTrain(**args)
    if args['embed_image_data']:
        from src.run.CellContrastEmbed import embed as CellContrastEmbed
        CellContrastEmbed(**args)


if __name__ == '__main__':
    args = vars(parse_args())
    main(**args)