This release focuses on the public CogFlow training and evaluation pipeline for two datasets:
rat
babel
The following public presets are supported:
rat: standard rat training and evaluation
babel: standard babel training and evaluation
The recommended public entry points are:
train.py
eval.py
pub_evaluation.py
conda create -n cogsde python=3.11 -y
conda activate cogsde
pip install -r requirements.txt
Before running training or evaluation, install a PyTorch build that matches your CUDA runtime.
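A quick way to confirm the environment before launching anything is a short report of the Python version and the installed PyTorch build. The sketch below is a hypothetical helper (`environment_report` is not part of this repository); it assumes the Python 3.11 target from the conda command above and only reads standard `torch` attributes (`torch.__version__`, `torch.version.cuda`, `torch.cuda.is_available()`).

```python
import importlib
import importlib.util
import sys


def environment_report():
    """Hypothetical pre-flight check; not part of the CogFlow codebase."""
    report = {
        "python": sys.version.split()[0],
        # The conda environment above targets Python 3.11.
        "python_is_3_11": sys.version_info[:2] == (3, 11),
    }
    # Avoid an ImportError when PyTorch has not been installed yet.
    if importlib.util.find_spec("torch") is None:
        report["torch"] = None
        return report
    torch = importlib.import_module("torch")
    report["torch"] = torch.__version__
    # CUDA version the wheel was built against; None for CPU-only builds.
    report["torch_cuda_build"] = torch.version.cuda
    # True only if a usable GPU and a compatible driver are visible.
    report["cuda_available"] = torch.cuda.is_available()
    return report


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```

If `torch_cuda_build` is set but `cuda_available` is False, the installed wheel most likely does not match the local driver or runtime.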
Download the public dataset packages and weight packages from:
https://drive.google.com/drive/folders/1yxv7f1Kbmaj-isupohGRdxznwEulZx0G?usp=sharing
then place them in the following locations.
Expected rat files:
data/rat/rat_pose_train.npy
data/rat/rat_stim_train.npy
data/rat/rat_pose_val.npy
data/rat/rat_stim_val.npy
The following optional aliases are also supported for evaluation:
data/rat/rat_pose_test.npy
data/rat/rat_stim_test.npy
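To catch a misplaced download early, the rat layout can be checked from the repository root. This is a minimal sketch with a hypothetical helper (`check_rat_files`); the file names are copied from the list above, and it only reports which optional test aliases are present without assuming how `eval.py` prioritizes them.

```python
from pathlib import Path

# File names copied from the expected rat layout.
RAT_DIR = Path("data") / "rat"
RAT_REQUIRED = [
    "rat_pose_train.npy",
    "rat_stim_train.npy",
    "rat_pose_val.npy",
    "rat_stim_val.npy",
]
# Optional aliases accepted for evaluation.
RAT_OPTIONAL = ["rat_pose_test.npy", "rat_stim_test.npy"]


def check_rat_files(root="."):
    """Hypothetical helper: return (missing required files, present optional aliases)."""
    base = Path(root) / RAT_DIR
    missing = [name for name in RAT_REQUIRED if not (base / name).is_file()]
    aliases = [name for name in RAT_OPTIONAL if (base / name).is_file()]
    return missing, aliases


if __name__ == "__main__":
    missing, aliases = check_rat_files()
    print("missing:", missing or "none")
    print("optional aliases found:", aliases or "none")
```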
Expected babel files:
data/babel/babel_train.npy
data/babel/babel_train_cmd.npy
data/babel/babel_val.npy
data/babel/babel_val_cmd.npy
data/babel/babel_test.npy
data/babel/babel_test_cmd.npy
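Each babel split ships as a pair: a data file and a matching `_cmd` file. The sketch below derives both paths from the split name and lists the splits whose pair is incomplete. The helper names are hypothetical, and it checks only file presence, not array contents or shapes.

```python
from pathlib import Path

# Split names taken from the expected babel layout.
SPLITS = ("train", "val", "test")


def babel_paths(root="."):
    """Hypothetical helper: map each split to its (data, command) file paths."""
    base = Path(root) / "data" / "babel"
    return {
        split: (base / f"babel_{split}.npy", base / f"babel_{split}_cmd.npy")
        for split in SPLITS
    }


def incomplete_splits(root="."):
    """Return the splits where either file of the pair is missing."""
    return [
        split
        for split, (data, cmd) in babel_paths(root).items()
        if not (data.is_file() and cmd.is_file())
    ]


if __name__ == "__main__":
    print("incomplete splits:", incomplete_splits() or "none")
```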
Place downloaded checkpoints here:
results_rat/cor_rat_fm_mn_std/m3_drift_diffusion/models/checkpoint_best.pt
results_babel/cor_babel_fm_m1_std/m3_drift_diffusion/models/checkpoint_best.pt
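Both checkpoint paths above appear to follow one pattern: `results_<dataset>/<config name>_std/m3_drift_diffusion/models/checkpoint_best.pt`. The sketch below rebuilds that path so you can check that a download landed in the right place. The pattern is inferred from these two examples only and is an assumption, not a documented rule of the codebase; `checkpoint_path` is a hypothetical helper.

```python
from pathlib import Path


def checkpoint_path(dataset: str, cfg_name: str, root: str = ".") -> Path:
    """Hypothetical helper; the layout is inferred from the two released checkpoints."""
    # Assumed pattern: results_<dataset>/<cfg_name>_std/m3_drift_diffusion/models/checkpoint_best.pt
    return (
        Path(root)
        / f"results_{dataset}"
        / f"{cfg_name}_std"
        / "m3_drift_diffusion"
        / "models"
        / "checkpoint_best.pt"
    )


if __name__ == "__main__":
    for dataset, cfg in [("rat", "cor_rat_fm_mn"), ("babel", "cor_babel_fm_m1")]:
        path = checkpoint_path(dataset, cfg)
        print(path, "found" if path.is_file() else "MISSING")
```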
Train on rat:
python train.py --cfg cfg/full_cfg/cor_rat_fm_mn.yml --exp rat_release
To train with the dissipativity term enabled, add --enable_dissipativity and set its weight:
python train.py --cfg cfg/full_cfg/cor_rat_fm_mn.yml --exp rat_test --enable_dissipativity --dissipativity_weight 0.001
Train on babel:
python train.py --cfg cfg/full_cfg/cor_babel_fm_m1.yml --exp babel_release
Evaluate on rat:
python eval.py --cfg cfg/full_cfg/cor_rat_eval_mn.yml
Evaluate on babel:
python eval.py --cfg cfg/full_cfg/cor_babel_fm_m1.yml
pub_evaluation.py runs a quick check that reproduces the released evaluation presets:
python pub_evaluation.py --npz_path cfg/full_cfg/npz/rat_cogflow.npz
python pub_evaluation.py --npz_path cfg/full_cfg/npz/babel_cogflow.npz
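If you prefer to drive the released commands from Python, for example to run both presets in sequence, a small wrapper can assemble the same argument lists and pass them to `subprocess.run`. This is a sketch, not part of the repository: the helper names are hypothetical, and only the config paths, experiment names, and flags are taken from the commands above. It uses `sys.executable` so the scripts run under the active conda environment.

```python
import subprocess
import sys

# Configs, experiment names, and npz paths copied from the commands above.
TRAIN = {
    "rat": ("cfg/full_cfg/cor_rat_fm_mn.yml", "rat_release"),
    "babel": ("cfg/full_cfg/cor_babel_fm_m1.yml", "babel_release"),
}
EVAL_CFG = {
    "rat": "cfg/full_cfg/cor_rat_eval_mn.yml",
    "babel": "cfg/full_cfg/cor_babel_fm_m1.yml",
}
PUB_NPZ = {
    "rat": "cfg/full_cfg/npz/rat_cogflow.npz",
    "babel": "cfg/full_cfg/npz/babel_cogflow.npz",
}


def train_cmd(dataset, exp=None, dissipativity_weight=None):
    """Hypothetical helper: build the train.py argument list."""
    cfg, default_exp = TRAIN[dataset]
    cmd = [sys.executable, "train.py", "--cfg", cfg, "--exp", exp or default_exp]
    if dissipativity_weight is not None:
        # Mirrors the optional dissipativity flags shown for rat.
        cmd += ["--enable_dissipativity", "--dissipativity_weight", str(dissipativity_weight)]
    return cmd


def eval_cmd(dataset):
    return [sys.executable, "eval.py", "--cfg", EVAL_CFG[dataset]]


def pub_eval_cmd(dataset):
    return [sys.executable, "pub_evaluation.py", "--npz_path", PUB_NPZ[dataset]]


def run(cmd):
    # check=True raises CalledProcessError if the script exits non-zero.
    subprocess.run(cmd, check=True)


if __name__ == "__main__":
    for dataset in ("rat", "babel"):
        run(pub_eval_cmd(dataset))
```

Run it from the repository root so the relative config and data paths resolve. For example, `run(train_cmd("rat", exp="rat_test", dissipativity_weight=0.001))` reproduces the dissipativity training command.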