Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 125 additions & 81 deletions cellcontrast.py
Original file line number Diff line number Diff line change
@@ -1,81 +1,125 @@
import argparse

def parse_args():
parser = argparse.ArgumentParser(description="Arguments for image model")

# Arguments for Image preprocessing
parser.add_argument("--preprocess_dir", type=str, default="data/raw/p2106",
help="Directory in which .tiff files are for preprocessing")
parser.add_argument("--preprocess_channels", type=str, default="",
help="Indices of channels to preprocess, seperated by , and empty if all channels")
parser.add_argument("--calc_mean_std", action="store_true", default=False,
help="Wether or not to calculate mean and std of cell cut outs")
parser.add_argument("--cell_cutout", type=int, default=20,
help="Size*Size cutout of cell, centered on Centroid Cell position")
parser.add_argument("--preprocess_workers", type=int, default=1,
help="Number of Workers to use for cell cutout")
parser.add_argument("--image_preprocess", action="store_true", default=False,
help="Wether or not to preprocess images via ZScore normalisation")

# General Model Arguments
parser.add_argument("--deterministic", action="store_true", default=False,
help="Wether or not to run NNs deterministicly")
parser.add_argument("--seed", type=int, default=42,
help="Seed for random computations")
parser.add_argument("--root_dir", type=str, default="data/",
help="Where to find the raw/ and processed/ dirs")
parser.add_argument("--raw_subset_dir", type=str, default="TMA1_preprocessed",
help="How the subdir in raw/ and processed/ is called")
parser.add_argument("--batch_size", type=int, default=256,
help="Number of elements per Batch")
parser.add_argument("--epochs", type=int, default=100,
help="Number of epochs for which to train")
parser.add_argument("--num_workers", type=int, default=1,
help="Number of worker processes to be used(loading data etc)")
parser.add_argument("--lr", type=float, default=0.1,
help="Learning rate of model")
parser.add_argument("--weight_decay", type=float, default=5e-6,
help="Weight decay of optimizer")
parser.add_argument("--early_stopping", type=int, default=100,
help="Number of epochs after which to stop model run without improvement to val loss")
parser.add_argument("--output_name", type=str, default="out/models/image_contrast.pt",
help="Path/name of moel for saving")

# Arguments for image model
parser.add_argument("--warmup_epochs", type=int, default=10,
help="Number of Epochs in which learning rate gets increased")
parser.add_argument("--embed", type=int, default=256,
help="Linear net size used to embed data")
parser.add_argument("--contrast", type=int, default=124,
help="Linear net size on which to calculate the contrast loss")
parser.add_argument("--crop_factor", type=float, default=0.5,
help="Cell Image crop factor for Image augmentation")
parser.add_argument("--resnet", type=str, default="18",
help="What ResNet model to choose, on of 18, 34, 50 and 101")
parser.add_argument("--n_clusters_image", type=int, default=1,
help="Number of Clusters to use for KMeans when only use when >= 1")
parser.add_argument("--train_image_model", action="store_true", default=False,
help="Wether or not to train the Image model")
parser.add_argument("--embed_image_data", action="store_true", default=False,
help="Wether or not to embed data with a given Image model")
return parser.parse_args()


def main(**args):
if args['image_preprocess']:
from src.utils.image_preprocess import image_preprocess as ImagePreprocess
ImagePreprocess(path=args['preprocess_dir'],
img_channels=args['preprocess_channels'],
do_mean_std=args['calc_mean_std'],
cell_cutout=args['cell_cutout'],
num_processes=args['preprocess_workers'])
if args['train_image_model']:
from src.run.CellContrastTrain import train as ImageTrain
ImageTrain(**args)
if args['embed_image_data']:
from src.run.CellContrastEmbed import embed as CellContrastEmbed
CellContrastEmbed(**args)

if __name__ == '__main__':
args = vars(parse_args())
main(**args)
import argparse

def parse_args():
parser = argparse.ArgumentParser(description="Arguments for image model")

# Arguments for Image preprocessing
parser.add_argument("--preprocess_dir", type=str, default="data/raw/p2106",
help="Directory in which .tiff files are for preprocessing")
parser.add_argument("--preprocess_channels", type=str, default="",
help="Indices of channels to preprocess, seperated by , and empty if all channels")
parser.add_argument("--calc_mean_std", action="store_true", default=False,
help="Wether or not to calculate mean and std of cell cut outs")
parser.add_argument("--cell_cutout", type=int, default=20,
help="Size*Size cutout of cell, centered on Centroid Cell position")
parser.add_argument("--preprocess_workers", type=int, default=1,
help="""Number of threads to use for loading data. Increasing `num_workers` past 24 may result
in large increases in CPU memory footprint. Only recommended for systems with
``>64 GB`` RAM.""")
parser.add_argument("--image_preprocess", action="store_true", default=False,
help="Wether or not to preprocess images via ZScore normalisation")

# General Model Arguments
parser.add_argument("--deterministic", action="store_true", default=False,
help="Wether or not to run NNs deterministicly")
parser.add_argument("--seed", type=int, default=42,
help="Seed for random computations")
parser.add_argument("--root_dir", type=str, default="data/",
help="Where to find the raw/ and processed/ dirs")
parser.add_argument("--raw_subset_dir", type=str, default="TMA1_preprocessed",
help="How the subdir in raw/ and processed/ is called")
parser.add_argument("--batch_size", type=int, default=256,
help="Number of elements per Batch")
parser.add_argument("--epochs", type=int, default=100,
help="Number of epochs for which to train")
parser.add_argument("--num_workers", type=int, default=1,
help="Number of worker processes to be used(loading data etc)")
parser.add_argument("--lr", type=float, default=0.1,
help="Learning rate of model")
parser.add_argument("--weight_decay", type=float, default=5e-6,
help="Weight decay of optimizer")
parser.add_argument("--early_stopping", type=int, default=100,
help="Number of epochs after which to stop model run without improvement to val loss")
parser.add_argument("--output_name", type=str, default="out/models/image_contrast.pt",
help="Path/name of moel for saving")

# Jonathan edit
parser.add_argument("--foundation_model",type=str, default="",
help="Which foundation model to use")
parser.add_argument("--channel_names",nargs='+',type=str, default="",
help="""
For '--foundation_model deepcell,':
Channel names, example '--channel_names CD3 CD8 CD20'
""")
parser.add_argument("--mpp",type=float, default=0.399,
help="""
For '--foundation_model deepcell,':
Microns per pixel, passed as float.
""")
parser.add_argument("--cell_ids",type=str, default="",
help="""
For --foundation_model deepcell/kronos,
cell ids can be extracted from a separate CSV
""")

# Arguments for image model
parser.add_argument("--warmup_epochs", type=int, default=10,
help="Number of Epochs in which learning rate gets increased")
parser.add_argument("--embed", type=int, default=256,
help="Linear net size used to embed data")
parser.add_argument("--contrast", type=int, default=124,
help="Linear net size on which to calculate the contrast loss")
parser.add_argument("--crop_factor", type=float, default=0.5,
help="Cell Image crop factor for Image augmentation")
parser.add_argument("--resnet", type=str, default="18",
help="What ResNet model to choose, on of 18, 34, 50 and 101")
parser.add_argument("--n_clusters_image", type=int, default=1,
help="Number of Clusters to use for KMeans when only use when >= 1")
parser.add_argument("--train_image_model", action="store_true", default=False,
help="Wether or not to train the Image model")
parser.add_argument("--embed_image_data", action="store_true", default=False,
help="Wether or not to embed data with a given Image model")
return parser.parse_args()


def main(**args):
if args['image_preprocess']:
if args['foundation_model'] == 'deepcell':
from src.utils.deepcell_kit.image_funcs import image_preprocess as ImagePreprocessDCT
ImagePreprocessDCT(path=args['preprocess_dir'],
channel_names=args['channel_names'],
cell_cutout=args['cell_cutout'],
mpp=args['mpp'],
batch_size=args['batch_size'],
ids_path=args['cell_ids'])
elif args['foundation_model'] == 'kronos':
from src.utils.kronos_kit.image_funcs import image_preprocess as ImagePreprocessKRONOS
ImagePreprocessKRONOS(path=args['preprocess_dir'],
channel_names=args['channel_names'],
cell_cutout=args['cell_cutout'],
batch_size=args['batch_size'],
ids_path=args['cell_ids'])
elif args['foundation_model'] == 'eva':
from src.utils.eva_kit.image_funcs import image_preprocess as ImagePreprocessEVA
ImagePreprocessEVA(path=args['preprocess_dir'],
channel_names=args['channel_names'],
cell_cutout=args['cell_cutout'],
batch_size=args['batch_size'],
ids_path=args['cell_ids'])
else:
from src.utils.image_preprocess import image_preprocess as ImagePreprocess
ImagePreprocess(path=args['preprocess_dir'],
img_channels=args['preprocess_channels'],
do_mean_std=args['calc_mean_std'],
cell_cutout=args['cell_cutout'],
num_processes=args['preprocess_workers'])
if args['train_image_model']:
from src.run.CellContrastTrain import train as ImageTrain
ImageTrain(**args)
if args['embed_image_data']:
from src.run.CellContrastEmbed import embed as CellContrastEmbed
CellContrastEmbed(**args)

if __name__ == '__main__':
args = vars(parse_args())
main(**args)
55 changes: 55 additions & 0 deletions celldownload.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import os
import argparse

"""

Model weights for EVA has to be downloaded from hugging face
https://huggingface.co/yandrewl/Eva
and renamed to:
eva_model.pt
and moved to:
out/models

And GenePT marker embeddings from
https://zenodo.org/records/10833191
Use the file:
GenePT_gene_protein_embedding_model_3_text.pickle
and store it as:
GenePT_embedding.pkl
in src/utils/eva_kit

KRONOS requires marker_metadata.csv from huggingface
https://huggingface.co/MahmoodLab/KRONOS
Download and place in:
src/utils/kronos_kit

"""

def parse_args():
parser = argparse.ArgumentParser(description=
"""
Script for downloading image encoders from foundation models.
""")
parser.add_argument("--model",type=str,default="",
help="""
Name of foundation model to download, available options are:
deepcell, kronos
""")
return parser.parse_args()

def main(**args):

model_path = os.path.join(os.getcwd(),"out","models")
if not os.path.exists():
os.makedirs(model_path)

if args["model"] == "deepcell":
from src.utils.download_utils.dct_download import dct_download
dct_download(model_path)
if args["model"] == "kronos":
from src.utils.download_utils.kr_download import kr_download
kr_download(model_path)

if __name__=="__main__":
args = vars(parse_args())
main(**args)
86 changes: 86 additions & 0 deletions src/data/DeepCellData.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import torch
from torch.utils.data import IterableDataset
import numpy as np
import yaml
import os
import pandas as pd
from tqdm import tqdm
from torch.utils.data import DataLoader
from skimage import io
from src.utils.deepcell_kit.config import DCTConfig

class PatchDataset():
"""
Dataset for single-image patchified data.
"""
def __init__(self,
root_dir='data/raw',
raw_subset_dir='',
split='train',
crop_factor=0.5,
n_clusters=1,
save_embed_data=False,
**args):

assert split in ['train', 'test'], f'split must be either train or test, but is {split}'

self.work_dir = os.path.join(os.getcwd(), root_dir, 'raw', raw_subset_dir)
self.img_dir = os.path.join(self.work_dir,split)
self.cells_path = [os.path.join(self.img_dir, p) for p in os.listdir(self.img_dir) if p.lower().endswith('_cells.npy')]
self.dct_config = DCTConfig()

def create_attn_mask(self, sample, max_channels):
# True = padding
# https://pytorch.org/docs/stable/generated/torch.ao.nn.quantizable.MultiheadAttention.html#torch.ao.nn.quantizable.MultiheadAttention.forward
mask = np.full((sample.shape[0], max_channels), True)
mask[:, 0 : sample.shape[1]] = False
return mask

def pad_images(self, sample, max_channels):
paddings = -1.0 # retrieved as a constant from repo (?)
return np.pad(
sample,
((0, 0), (0, max_channels - sample.shape[1]), (0, 0), (0, 0), (0, 0)),
mode="constant",
constant_values=paddings,
)

def save_embed_data(self, model, device='cpu', batch_size=256):
"""
Save model representations of all cells per ROI.

model (torch.Module): deepcell model
device (str): device to operate on
batch_size (int): Number of cells to extract representations from at once
"""
ch_idx_path = [os.path.join(self.work_dir, p) for p in os.listdir(self.work_dir) if p.split(os.sep)[-1] == 'channel_idx.npy'][0]
ch_idx = torch.from_numpy(np.load(ch_idx_path))
model.eval()
with torch.no_grad():
for path in tqdm(self.cells_path, desc='Save embeddings'):

sample = torch.from_numpy(np.load(path)) # (B,C,3,H,W)

attn_mask = self.create_attn_mask(sample, self.dct_config.MAX_NUM_CHANNELS) # (C_max,)
sample = self.pad_images(sample, self.dct_config.MAX_NUM_CHANNELS) # (C_max, 3, H, W)
sample, attn_mask = torch.as_tensor(sample, dtype=torch.float32), torch.as_tensor(attn_mask, dtype=torch.bool)

embed = torch.empty((sample.shape[0], model.embed_size), dtype=torch.float32)
num_batches = (sample.shape[0]+batch_size-1) // batch_size

for batch_idx in range(num_batches):
if batch_idx < num_batches - 1:
idx_start = batch_idx*batch_size
idx_end = batch_idx*batch_size+batch_size
else:
idx_start = batch_idx*batch_size
idx_end = len(embed)
embed[idx_start:idx_end] = model(
sample[idx_start:idx_end].to(device),
ch_idx.to(device),
attn_mask[idx_start:idx_end].to(device)
).to('cpu')
torch.save(embed, os.path.join(path, path.split('.')[0]+'_embed.pt'))
del sample
del attn_mask
del embed
Loading