From 3997b815c0d675b1b3c791f035760432af570fd3 Mon Sep 17 00:00:00 2001 From: lukasvo76 Date: Wed, 1 Apr 2026 17:02:53 +0200 Subject: [PATCH 1/2] bug fixes to canlab scripts, and completely revamped searchlight_disti --- .../hansen_neurotransmitter_maps.m | 183 +++-- .../@image_vector/image_similarity_plot.m | 97 ++- CanlabCore/@image_vector/searchlightLukas.m | 718 ------------------ .../canlab_pattern_similarity.m | 23 +- .../searchlight_disti_Lukas.m | 330 ++++++++ 5 files changed, 558 insertions(+), 793 deletions(-) delete mode 100644 CanlabCore/@image_vector/searchlightLukas.m create mode 100644 CanlabCore/Statistics_tools/searchlight_disti_Lukas.m diff --git a/CanlabCore/@image_vector/hansen_neurotransmitter_maps.m b/CanlabCore/@image_vector/hansen_neurotransmitter_maps.m index eefa846d..9cbbe033 100644 --- a/CanlabCore/@image_vector/hansen_neurotransmitter_maps.m +++ b/CanlabCore/@image_vector/hansen_neurotransmitter_maps.m @@ -41,6 +41,10 @@ % Indicates group membership for each image as vector, categorical, % or string array % +% **mask** +% Followed by mask on path, default canlab gray matter mask +% Mask will be applied to both fmri_data_obj and hansen maps +% % :Outputs: % % **stats:** @@ -128,6 +132,8 @@ % this function, and debugged plotting for multiple groups in % image_similarity_plot % +% Lukas: Added mask option +% % .. % ------------------------------------------------------------------------- @@ -137,8 +143,9 @@ colors = [1 0 0]; dofigure = true; doplot = true; -similarity_metric = 'corr'; +similarity_metric = 'correlation'; dofixrange = []; +mask = which('gray_matter_mask.nii'); doAverage=0; % ------------------------------------------------------------------------- @@ -150,7 +157,7 @@ allowable_inputs = {'colors' 'doplot' 'similarity_metric' 'dofixrange'}; -keyword_inputs = {'noplot' 'nofigure' 'cosine_similarity' 'doAverage' 'compareGroups'}; +keyword_inputs = {'noplot' 'nofigure' 'cosine_similarity' 'doAverage' 'compareGroups', 'mask'}; % optional inputs with default values - each keyword entered will create a variable of the same name @@ -190,11 +197,22 @@ case 'compareGroups' compareGroups = true; group = varargin{i+1}; + + case 'mask' + mask = varargin{i+1}; end end end +% ------------------------------------------------------------------------- +% INTIALIZE OUTPUT +% ------------------------------------------------------------------------- +stats = struct(); +[hh, hhfill] = deal(' '); +table_group = {}; +multcomp_group = {}; + % ------------------------------------------------------------------------- % MAIN FUNCTION % ------------------------------------------------------------------------- @@ -204,69 +222,144 @@ ntmaps = reorder_and_add_metadata(ntmaps); % These are already gray-matter masked in repo, but make sure: -ntmaps = apply_mask(ntmaps, which('gray_matter_mask.nii')); +ntmaps = apply_mask(ntmaps, mask); % This may not be masked...so mask with gray matter: -fmri_data_obj = apply_mask(fmri_data_obj, which('gray_matter_mask.nii')); +fmri_data_obj = apply_mask(fmri_data_obj, mask); + +% handle tags for figure(s) +if dofigure + tagname = ['Neurotransmitter polar plot ' similarity_metric]; + old = findobj('Tag', tagname); + old = old( strcmp( get(old, 'Type'), 'figure' ) ); -if dofigure - create_figure('Neurotransmitter polar plot') + if ~isempty(old) % Found existing figure window with this tag + create_figure([tagname ' ' num2str(length(old)+1)]) + else + create_figure(tagname) + end end if doplot if ~iscell(colors), colors = {colors}; end + + switch similarity_metric + + case 'cosine_similarity' + + if doAverage==1 + if isempty(dofixrange) + + if exist('compareGroups','var') % added by Lukas: if we want to analyze & plot multiple groups + + groupValues = unique(group, 'stable'); + if size(colors,1) ~= size(groupValues,1) + colors = scn_standard_colors(length(groupValues))'; + end + + [stats, hh, hhfill, table_group, multcomp_group] = image_similarity_plot(fmri_data_obj, 'mapset', ntmaps, similarity_metric, 'plotstyle', 'polar', 'networknames', ntmaps.metadata_table.target, 'colors', colors, 'nofigure', 'average','compareGroups', group); + + else + + [stats, hh, hhfill] = image_similarity_plot(fmri_data_obj, 'mapset', ntmaps, similarity_metric, 'plotstyle', 'polar', 'networknames', ntmaps.metadata_table.target, 'colors', colors, 'nofigure', 'average','Error_STD'); + end + + else + + if exist('compareGroups','var') % % added by Lukas: if we want to analyze & plot multiple groups + + groupValues = unique(group, 'stable'); + if size(colors,1) ~= size(groupValues,1) + colors = scn_standard_colors(length(groupValues))'; + end + + [stats, hh, hhfill, table_group, multcomp_group] = image_similarity_plot(fmri_data_obj, 'mapset', ntmaps, similarity_metric, 'plotstyle', 'polar', 'networknames', ntmaps.metadata_table.target, 'colors', colors, 'nofigure', 'dofixrange', dofixrange, 'average', 'compareGroups', group); + + else + + [stats, hh, hhfill] = image_similarity_plot(fmri_data_obj, 'mapset', ntmaps, similarity_metric, 'plotstyle', 'polar', 'networknames', ntmaps.metadata_table.target, 'colors', colors, 'nofigure', 'dofixrange', dofixrange,'average','Error_STD'); + end - if doAverage==1 - if isempty(dofixrange) - - if exist('compareGroups','var') % added by Lukas: if we want to analyze & plot multiple groups - - groupValues = unique(group, 'stable'); - if size(colors,1) ~= size(groupValues,1) - colors = scn_standard_colors(length(groupValues))'; end - - [stats, hh, hhfill, table_group, multcomp_group] = image_similarity_plot(fmri_data_obj, 'mapset', ntmaps, similarity_metric, 'plotstyle', 'polar', 'networknames', ntmaps.metadata_table.target, 'colors', colors, 'nofigure', 'average','compareGroups', group); - + else - - [stats, hh, hhfill, table_group, multcomp_group] = image_similarity_plot(fmri_data_obj, 'mapset', ntmaps, similarity_metric, 'plotstyle', 'polar', 'networknames', ntmaps.metadata_table.target, 'colors', colors, 'nofigure', 'average','Error_STD'); + if isempty(dofixrange) + + [stats, hh, hhfill] = image_similarity_plot(fmri_data_obj, 'mapset', ntmaps, similarity_metric, 'plotstyle', 'polar', 'networknames', ntmaps.metadata_table.target, 'colors', colors, 'nofigure'); + + else % we have fixed range + + [stats, hh, hhfill] = image_similarity_plot(fmri_data_obj, 'mapset', ntmaps, similarity_metric, 'plotstyle', 'polar', 'networknames', ntmaps.metadata_table.target, 'colors', colors, 'nofigure', 'dofixrange', dofixrange); + + end end + + case 'correlation' - else - - if exist('compareGroups','var') % % added by Lukas: if we want to analyze & plot multiple groups - - groupValues = unique(group, 'stable'); - if size(colors,1) ~= size(groupValues,1) - colors = scn_standard_colors(length(groupValues))'; + if doAverage==1 + if isempty(dofixrange) + + if exist('compareGroups','var') % added by Lukas: if we want to analyze & plot multiple groups + + groupValues = unique(group, 'stable'); + if size(colors,1) ~= size(groupValues,1) + colors = scn_standard_colors(length(groupValues))'; + end + + [stats, hh, hhfill, table_group, multcomp_group] = image_similarity_plot(fmri_data_obj, 'mapset', ntmaps, 'plotstyle', 'polar', 'networknames', ntmaps.metadata_table.target, 'colors', colors, 'nofigure', 'average','compareGroups', group); + + else + + [stats, hh, hhfill] = image_similarity_plot(fmri_data_obj, 'mapset', ntmaps, 'plotstyle', 'polar', 'networknames', ntmaps.metadata_table.target, 'colors', colors, 'nofigure', 'average','Error_STD'); + end + + else + + if exist('compareGroups','var') % % added by Lukas: if we want to analyze & plot multiple groups + + groupValues = unique(group, 'stable'); + if size(colors,1) ~= size(groupValues,1) + colors = scn_standard_colors(length(groupValues))'; + end + + [stats, hh, hhfill, table_group, multcomp_group] = image_similarity_plot(fmri_data_obj, 'mapset', ntmaps, 'plotstyle', 'polar', 'networknames', ntmaps.metadata_table.target, 'colors', colors, 'nofigure', 'dofixrange', dofixrange, 'average', 'compareGroups', group); + + else + + [stats, hh, hhfill] = image_similarity_plot(fmri_data_obj, 'mapset', ntmaps, 'plotstyle', 'polar', 'networknames', ntmaps.metadata_table.target, 'colors', colors, 'nofigure', 'dofixrange', dofixrange,'average','Error_STD'); + end + end - - [stats, hh, hhfill, table_group, multcomp_group] = image_similarity_plot(fmri_data_obj, 'mapset', ntmaps, similarity_metric, 'plotstyle', 'polar', 'networknames', ntmaps.metadata_table.target, 'colors', colors, 'nofigure', 'dofixrange', dofixrange, 'average', 'compareGroups', group); else - - [stats, hh, hhfill, table_group, multcomp_group] = image_similarity_plot(fmri_data_obj, 'mapset', ntmaps, similarity_metric, 'plotstyle', 'polar', 'networknames', ntmaps.metadata_table.target, 'colors', colors, 'nofigure', 'dofixrange', dofixrange,'average','Error_STD'); + if isempty(dofixrange) + + [stats, hh, hhfill] = image_similarity_plot(fmri_data_obj, 'mapset', ntmaps, 'plotstyle', 'polar', 'networknames', ntmaps.metadata_table.target, 'colors', colors, 'nofigure'); + + else % we have fixed range + + [stats, hh, hhfill] = image_similarity_plot(fmri_data_obj, 'mapset', ntmaps, 'plotstyle', 'polar', 'networknames', ntmaps.metadata_table.target, 'colors', colors, 'nofigure', 'dofixrange', dofixrange); + + end end - end - - else - if isempty(dofixrange) - - [stats, hh, hhfill, table_group, multcomp_group] = image_similarity_plot(fmri_data_obj, 'mapset', ntmaps, similarity_metric, 'plotstyle', 'polar', 'networknames', ntmaps.metadata_table.target, 'colors', colors, 'nofigure'); - - else % we have fixed range - - [stats, hh, hhfill, table_group, multcomp_group] = image_similarity_plot(fmri_data_obj, 'mapset', ntmaps, similarity_metric, 'plotstyle', 'polar', 'networknames', ntmaps.metadata_table.target, 'colors', colors, 'nofigure', 'dofixrange', dofixrange); - - end - end + end % switch similarity metric else - [stats, hh, hhfill, table_group, multcomp_group] = image_similarity_plot(fmri_data_obj, 'mapset', ntmaps, similarity_metric, 'noplot'); + + switch similarity_metric + + case 'cosine_similarity' + + [stats, hh, hhfill] = image_similarity_plot(fmri_data_obj, 'mapset', ntmaps, similarity_metric, 'noplot'); + + case 'correlation' + + [stats, hh, hhfill] = image_similarity_plot(fmri_data_obj, 'mapset', ntmaps, 'noplot'); + + end % switch similarity metric -end +end % if doplot end % main function diff --git a/CanlabCore/@image_vector/image_similarity_plot.m b/CanlabCore/@image_vector/image_similarity_plot.m index 281f86e2..b63acdde 100644 --- a/CanlabCore/@image_vector/image_similarity_plot.m +++ b/CanlabCore/@image_vector/image_similarity_plot.m @@ -185,6 +185,10 @@ % - .multcomp_spatial, multiple comparisons of means across % different spatial bases, critical value determined % by Tukey-Kramer method (see multcompare) +% - .group.p list of p-values from between-group ANOVAs (see +% table_group below) +% - .group.q corresponding FDR-corrected p-values +% % **hh:** % Handles to lines % @@ -194,7 +198,7 @@ % **table_group** % multiple one-way ANOVA tables (one for each % spatial basis) with group as column factor (requires -% 'average' to be specified) +% 'average' and 'compareGroups' to be specified) % % **multcomp_group** % mutiple comparisons of means across groups, one output @@ -261,6 +265,10 @@ % 2023/12/19 Lukas Van Oudenhove % - debugged and improved wedge and polarplot code for multiple % groups, see notes in code below for details +% 2026/01/13 Lukas Van Oudenhove +% - added fdr corrected p-values to correlation and ANOVA tables +% - corrected stars for significances for group comparison polar plot + % PRELIMINARIES % ------------------------------------------------------------------------ @@ -285,6 +293,7 @@ % functions, e.g., riverplot mapset = 'bucknerlab'; % 'bucknerlab' +stats = struct(); %initialize output table_group = {}; %initialize output multcomp_group = {}; %initialize output dofigure = true; @@ -321,7 +330,6 @@ force_noaverage = true; case 'cosine_similarity', sim_metric = 'cosine'; -% case 'cosine_similarity', sim_metric = 'corr'; case 'binary_overlap', sim_metric = 'overlap'; @@ -607,19 +615,42 @@ groupValues=unique(group, 'stable'); g=num2cell(groupValues); %create cell array of group numbers - for i=1:size(z,2) %for each spatial basis do an anova across groups + for i=1:size(z,2) % for each spatial basis do an anova across groups - [p, table_group{i}, st]=anova1(z(:,i), group, 'off'); %get anova table - [c,~] = multcompare(st, 'Display', 'off'); %perform multiple comparisons + [stats.group.p(i), table_group{i}, st]=anova1(z(:,i), group, 'off'); %get anova table + [c,~] = multcompare(st, 'Display', 'off'); % perform multiple comparisons multcomp_group{i}=[g(c(:,1)), g(c(:,2)), num2cell(c(:,3:end)), pValueToStars(c(:,end))]; %format table for output + + if length(groupValues) > 1 + starCellArray_bg(i) = pValueToStars(c(:,end)); + end end - + + if strcmp(sim_metric,'corr') + stats.group.descrip = 'one-way ANOVA comparing Fisher''s r to Z transformed point-biserial correlations between groups'; + else + stats.group.descrip = ['one-way ANOVA comparing raw similarity measured by ' sim_metric ' between groups']; + end + + [~, stats.group.q, stats.group.aprioriprob] = mafdr(stats.group.p); + if stats.group.aprioriprob > .99 + stats.group.q = mafdr(stats.group.p, 'BHFDR',true); % B-H rather than Storey if estimated probability of true positives = 1 as in SAS proc multtest + label_FDR_group = 'q_B-H'; + else + for q = 1:size(stats.group.q,2) % implementing the constraint q >= p as in SAS proc multtest + if stats.group.q(1,q) < stats.group.p(1,q) + stats.group.q(1,q) = stats.group.p(1,q); + end + end + label_FDR_group = 'q_Storey'; + end + if printTable for i=1:size(z,2) disp(['Between-group comparisons for ' networknames{i} ':']); disp('--------------------------------------'); - disp(['One-way ANOVA: F(' num2str(table_group{i}{2,3}) ',' num2str(table_group{i}{3,3}) ') = ' num2str(table_group{i}{2,5},3) ', P = ' num2str(table_group{i}{2,6},3)]) + disp(['One-way ANOVA: F(' num2str(table_group{i}{2,3}) ',' num2str(table_group{i}{3,3}) ') = ' num2str(table_group{i}{2,5},3) ', P = ' num2str(table_group{i}{2,6},3) ', ' label_FDR_group ' = ' num2str(stats.group.q(1,i))]) disp(' ') disp('Multiple comparisons of means:') disp(' '); @@ -692,25 +723,51 @@ stats(g).t = stat.tstat'; stats(g).df = stat.df'; starCellArray = pValueToStars(stats(g).p); + + if size(stats(g).p,1) > 1 + + [~, stats(g).q, stats(g).aprioriprob] = mafdr(stats(g).p); + if stats(g).aprioriprob > .99 + stats(g).q = mafdr(stats(g).p, 'BHFDR',true); % B-H rather than Storey if estimated probability of true positives = 0 as in SAS proc multtest + label_FDR = 'q_B-H'; + else + for q = 1:size(stats(g).q,1) % implementing the constraint q >= p as in SAS proc multtest + if stats(g).q(q) < stats(g).p(q) + stats(g).q(q) = stats(g).p(q); + end + end + label_FDR = 'q_Storey'; + end + - %perform repeated measures anova (two way anova with subject as the - %row factor - [~, stats(g).table_spatial, st]=anova2(z_group(~any(isnan(z_group')),:),1,'off'); - [c,~] = multcompare(st,'Display','off'); - stats(g).multcomp_spatial=[networknames(c(:,1))', networknames(c(:,2))', num2cell(c(:,3:end))]; + %perform repeated measures anova (two way anova with subject as the + %row factor + [~, stats(g).table_spatial, st]=anova2(z_group(~any(isnan(z_group')),:),1,'off'); + [c,~] = multcompare(st,'Display','off'); + stats(g).multcomp_spatial=[networknames(c(:,1))', networknames(c(:,2))', num2cell(c(:,3:end))]; - if printTable - disp(['Table of correlations Group:' num2str(g)]); - disp('--------------------------------------'); - disp(stats(g).descrip) + if printTable + if length(groupValues) == 1 + disp('Table of correlations entire sample'); + else + disp(['Table of correlations Group:' num2str(g)]); + end + disp('--------------------------------------'); + disp(stats(g).descrip) - print_matrix([m(:,g) stats(g).t stats(g).p stats(g).sig], {'R_avg' 'T' 'P' 'sig'}, networknames, '%3.4f', starCellArray); + print_matrix([m(:,g) stats(g).t stats(g).p stats(g).sig, stats(g).q], {'R_avg' 'T' 'P' 'sig', label_FDR}, networknames, '%3.4f', starCellArray); - disp(' '); - end + disp(' '); + end - networknames=strcat(networknames, starCellArray'); + if length(groupValues) == 1 + networknames=strcat(networknames, starCellArray'); + else + networknames=strcat(networknames, starCellArray_bg); + end + + end end %groups diff --git a/CanlabCore/@image_vector/searchlightLukas.m b/CanlabCore/@image_vector/searchlightLukas.m deleted file mode 100644 index 8c99a3ae..00000000 --- a/CanlabCore/@image_vector/searchlightLukas.m +++ /dev/null @@ -1,718 +0,0 @@ -function [results_obj, stats, indx] = searchlightLukas(dat, varargin) -% Run searchlight multivariate prediction/classification on an image_vector -% or fmri_data object OR two objects, for cross-prediction. -% -% :Usage: -% :: -% -% [list outputs here] = function_name(list inputs here, [optional inputs]) -% [results_obj, stats, indx] = searchlight(dat, [optional inputs]) -% -% -% :Features: -% - Runs searchlight with standard, pre-defined algorithms -% - Custom-entry definition of holdout sets -% - Can re-use searchlight spheres after initial definition -% - Custom-entry definition of any spheres/regions of interest -% - Uses Matlab's parallel processing toolbox (for) -% -% Type help image_vector.searchlight to display this help information -% -% .. -% Author and copyright information: -% -% Copyright (C) 2014 Tor Wager and... -% -% This program is free software: you can redistribute it and/or modify -% it under the terms of the GNU General Public License as published by -% the Free Software Foundation, either version 3 of the License, or -% (at your option) any later version. -% -% This program is distributed in the hope that it will be useful, -% but WITHOUT ANY WARRANTY; without even the implied warranty of -% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -% GNU General Public License for more details. -% -% You should have received a copy of the GNU General Public License -% along with this program. If not, see . -% .. -% -% -% :Inputs: -% -% **dat:** -% image_vector or fmri_data object with data -% -% **dat.Y:** -% required: true outcomes for each observation (image) in dat -% -% :Optional Inputs:* Keyword followed by input variable: -% -% **r:** -% searchlight radius, voxels -% **dat2:** -% second dataset, for cross-prediction -% **indx:** -% sparse logical matrix. each COLUMN is index of inclusion sets for each region/sphere in searchlight -% This takes a long time to calculate, but can be saved and -% re-used for a given mask -% **holdout_set:** -% Followed by integer vector of which observations belong to which -% holdout set, for cross-validation. This is passed into fmri_data.predict.m. Default is -% empty. -% -% :Outputs: -% -% **results_obj:** -% fmri_data object with results maps -% -% **stats:** -% selected statistics for each sphere in searchlight -% -% **indx:** -% sparse logical matrix. each COLUMN is index of inclusion sets for each region/sphere in searchlight -% * this can be re-used for all data with the same mask/structure. * -% -% -% :Examples: -% :: -% -% % Define a sensible gray-matter mask: -% dat = fmri_data(which('scalped_avg152T1_graymatter.img')); -% dat = threshold(dat, [.8 Inf], 'raw-between'); -% dat = trim_mask(dat); -% -% % Create fake data and holdout indicator index vector -% dat.dat = randn(dat.volInfo.n_inmask, 30); -% dat.Y = dat.dat(111111, :)' + .3 * randn(30, 1); -% holdout_set = ones(6, 1); for i = 2:5, holdout_set = [holdout_set; i*ones(6, 1)]; end -% -% % Run, and run again with existing indx -% pool = parpool(12); % initialize parallel processing (12 cores) -% [results_obj, stats, indx] = searchlight(dat, 'holdout_set', holdout_set); -% results_obj = searchlight(dat, 'holdout_set', holdout_set, 'indx', indx); -% -% :See also: -% region.m, fmri_data.predict.m -% -% .. -% Programmers' notes: -% List dates and changes here, and author of changes -% .. - -% .. -% DEFAULTS AND INPUTS -% .. - -r = 5; % For defining regions -indx = []; % cell; with logical index of inclusion sets for each region/sphere in searchlight -do_online = false; -no_weights = false; % don't save weight objects (can lead to memory issues) - -% For prediction/classification -algorithm_name = 'cv_lassopcr'; -holdout_set = []; -% results_obj = []; - -% process variable arguments -for i = 1:length(varargin) - if ischar(varargin{i}) - switch varargin{i} - % functional commands - case {'r', 'indx', 'algorithm_name', 'dat2', 'holdout_set'} - str = [varargin{i} ' = varargin{i + 1};']; - eval(str) - varargin{i} = []; - varargin{i + 1} = []; - - case {'cross_predict'} - algorithm_name = varargin{i}; - - case {'do_online'} - do_online = true; - indx = []; - - - case {'no_weights'} - no_weights = true; - indx = []; - - otherwise - % use all other inputs as optional arguments to specific - % functions, to be interpreted by them - predfun_inputs{end + 1} = varargin{i}; - - if (i+1) <= length(varargin) && ~ischar(varargin{i + 1}) - predfun_inputs{end + 1} = varargin{i + 1}; - end - - end - end -end - -% n = dat.volInfo.n_inmask; -n = size(dat.dat,1); % lukasvo76 changed since dat.volInfo is not updated after applying mask, causing an index out out bounds error on line 207 > line 401; trim_mask should solve this too I guess - - -if ~do_online - %% Set up indices for spherical searchlight, if not entered previously - - if isempty(indx) - - indx = searchlight_sphere_prep(dat, r); - - else - - fprintf('Using input indx to define regions/spheres...\n'); - - end -end - -%% Check for data and return if empty - -if ~isa(dat, 'fmri_data') || isempty(dat.Y) - - fprintf('Returning indx only: No data in dat.Y to predict or data is not an fmri_data object'); - - return - -end - - -%% Run cross-predict (or other) function - -% Get rough time estimate first - -if ~do_online -predict_time_estimate(dat, indx, algorithm_name, holdout_set); -end - -t = tic; -fprintf('Running prediction in each region...'); - -% for i = 1:n -% -% indxcell{i} = indx(:, i); -% -% end - -output_val = cell(1, n); -voxlist = dat.volInfo.xyzlist; -not_removed = ~dat.removed_voxels; -xyzlist = voxlist(not_removed,:); - -for i = 1:n - - if do_online - - wh_voxels = faster_index(xyzlist, i, r); - - else - - wh_voxels = indx(:,i); - - end - - output_val{i} = predict_wrapper(dat, wh_voxels, algorithm_name, holdout_set); - - if no_weights - - output_val{i}.weight_obj =[]; - - end - -end - -e = toc(t); - -[hour, minute, second] = sec2hms(e); -fprintf(1,'Done in %3.0f hours %3.0f min %2.0f sec\n',hour, minute, second); - - -%% Save results map(s) in object - -results_obj = mean(dat); -stats = [output_val{:}]; -dat_temp = [stats.cverr]'; -results_obj.dat = dat_temp; - - -end % function - - - -% ------------------------------------------------------------------------- -% ------------------------------------------------------------------------- -% -% Sub-functions -% -% ------------------------------------------------------------------------- -% ------------------------------------------------------------------------- - - -% SPHERE PREP -% ------------------------------------------------------------------------- - -function indx = searchlight_sphere_prep(dat, r) - -n = dat.volInfo.n_inmask; -indx = cell(1, n); - -t = tic; -fprintf('Setting up seeds...'); - -for i = 1:n - seed{i} = dat.volInfo.xyzlist(i, :); -end - -e = toc(t); -fprintf('Done in %3.2f sec\n', e); - -% Set up indices for spherical searchlight -% ------------------------------------------------------------------------- -% These could be indices for ROIs, user input, previously saved indices... - -% First, a rough time estimate: -% ------------------------------------------------------------------------- -fprintf('Searchlight sphere construction can take 20 mins or more! (est: 20 mins with 8 processors/gray matter mask)\n'); -fprintf('It can be re-used once created for multiple analyses with the same region definitions\n'); -fprintf('Getting a rough time estimate for how long this will take...\n'); - -n_to_run = min(500, n); -t = tic; - -for i = 1:n_to_run - - mydist = sum([dat.volInfo.xyzlist(:, 1) - seed{i}(1) dat.volInfo.xyzlist(:, 2) - seed{i}(2) dat.volInfo.xyzlist(:, 3) - seed{i}(3)] .^ 2, 2); - indx{i} = mydist <= r.^2; - -end -e = toc(t); -estim = e * n / n_to_run; - -[hour, minute, second] = sec2hms(estim); -fprintf(1,'\nEstimate for whole brain = %3.0f hours %3.0f min %2.0f sec\n',hour, minute, second); - -% Second, do it for all voxels/spheres: -% ------------------------------------------------------------------------- - -t = tic; - -fprintf('Constructing spheres for each seed...'); - -%infdist = inf * ones(size(dat.volInfo.xyzlist(:, 1), 1), 1); - -for i = 1:n - - mydist = sum([dat.volInfo.xyzlist(:, 1) - seed{i}(1) dat.volInfo.xyzlist(:, 2) - seed{i}(2) dat.volInfo.xyzlist(:, 3) - seed{i}(3)] .^ 2, 2); - indx{i} = mydist <= r.^2; - -end - -e = toc(t); -fprintf('Done in %3.2f sec\n', e); - -% Make sparse matrix -t = tic; -fprintf('Sparsifying region indices...'); - -indxcat = cat(2, indx{:}); -indxcat = sparse(indxcat); - -e = toc(t); -fprintf('Done in %3.2f sec\n', e); - - -% NOTE: IdenTICAL but twice as slow -% indx1 = cell(1, n); -% tic -% for i = 1:1000 -% indx1{i} = logical(iimg_xyz2spheres(seed{i}, dat.volInfo.xyzlist, r)); -% end -% toc - - -end % searchlight_sphere_prep - - - - -% PREDICTION -% ------------------------------------------------------------------------- - -function output_val = predict_wrapper(dat, indx, algorithm_name, holdout_set) -% This function wraps predict.m to use its functionality. -% It could be extended to handle optional inputs, etc. -% Right now it is basic. -% Note: pass in indx{i}, which becomes indx here, for ONE -% region/searchlight sphere - -dat.dat = dat.dat(indx, :); -dat.removed_voxels(~indx) = true; - -[cverr, stats] = predict(dat, 'algorithm_name', algorithm_name, 'nfolds', holdout_set, 'useparallel', 1, 'verbose', 0); - -% Parse output -% Only storing important info and full data weight map - we will probably -% want xval weight maps too at some point. We will also probably want a -% p-value and standard errors for accuracy too. - -output_val = struct; -switch algorithm_name - case 'cv_svm' - output_val.cverr = stats.cverr; - output_val.dist_from_hyperplane_xval = stats.dist_from_hyperplane_xval; - output_val.weight_obj = stats.weight_obj; - output_val.intercept = stats.other_output{3}; - - case 'cv_lassopcr' - output_val.yfit = stats.yfit; - output_val.pred_outcome_r = stats.pred_outcome_r; - output_val.weight_obj = stats.weight_obj; - output_val.intercept = stats.other_output{3}; -end - -end % subfunction - - - -% PREDICTION time estimate -% ------------------------------------------------------------------------- - -function predict_time_estimate(dat, indx, algorithm_name, holdout_set) - -t = tic; -fprintf('Getting rough time estimate: Running prediction for up to 500 voxels...'); - -n = dat.volInfo.n_inmask; -output_val = cell(1, n); - -n_to_run = min(500, n); - -parfor i = 1:n_to_run - - output_val{i} = predict_wrapper(dat, indx(:, i), algorithm_name, holdout_set); - -end - -e = toc(t); -fprintf('Done in %3.2f sec\n', e); - -estim = e * n / n_to_run; - -[hour, minute, second] = sec2hms(estim); -fprintf(1,'Estimate for whole brain = %3.0f hours %3.0f min %2.0f sec\n',hour, minute, second); - -end - - -function indx = faster_index(xyz, i, r) -seed = xyz(i,:); -indx = sum([xyz(:,1)-seed(1) xyz(:,2)-seed(2) xyz(:,3)-seed(3)].^2, 2) <= r.^2; -end - - -% Convert time -% ------------------------------------------------------------------------- - - -function [hour, minute, second] = sec2hms(sec) -%SEC2HMS Convert seconds to hours, minutes and seconds. -% -% [HOUR, MINUTE, SECOND] = SEC2HMS(SEC) converts the number of seconds in -% SEC into hours, minutes and seconds. - -hour = fix(sec/3600); % get number of hours -sec = sec - 3600*hour; % remove the hours -minute = fix(sec/60); % get number of minutes -sec = sec - 60*minute; % remove the minutes -second = sec; -end - -% % WANI'S CROSS-CLASSIFICATION FUNCTION - to edit and integrate -% % ------------------------------------------------------------------------- -% -% function [acc, p, se, stats1 stats2] = cv_svm_cross_subf(data1, data2, cv_assign, dobalanced, balanced_ridge, doscale) -% -% % run cv_svm on train data and test data -% -% if doscale -% data1 = rescale(data1, 'zscoreimages'); -% data2 = rescale(data2, 'zscoreimages'); -% end -% -% % preallocate some variables -% predicted11 = NaN(size(data1.Y)); -% predicted12 = NaN(size(data2.Y)); -% predicted21 = NaN(size(data1.Y)); -% predicted22 = NaN(size(data2.Y)); -% -% dist_from_hyper11 = NaN(size(data1.Y)); -% dist_from_hyper12 = NaN(size(data2.Y)); -% dist_from_hyper21 = NaN(size(data1.Y)); -% dist_from_hyper22 = NaN(size(data2.Y)); -% -% % get cv assignment -% u = unique(cv_assign); -% nfolds = length(u); -% [trIdx, teIdx] = deal(cell(1, nfolds)); -% -% for i = 1:length(u) -% teIdx{i} = cv_assign == u(i); -% trIdx{i} = ~teIdx{i}; -% end -% -% stats1.Y = data1.Y; -% stats2.Y = data2.Y; -% -% % Use all -% trainobj1 = data(double(data1.dat'),data1.Y); -% trainobj2 = data(double(data2.dat'),data2.Y); -% -% svmobj1 = svm({'optimizer="andre"','C=1','child=kernel'}); -% if dobalanced, svmobj1.balanced_ridge = balanced_ridge; end -% [~, svmobj1] = train(svmobj1, trainobj1); -% -% svmobj2 = svm({'optimizer="andre"','C=1','child=kernel'}); -% if dobalanced, svmobj2.balanced_ridge = balanced_ridge; end -% [~, svmobj2] = train(svmobj2, trainobj2); -% -% % 1) weights -% stats1.all{1} = get_w(svmobj1)'; -% stats2.all{1} = get_w(svmobj2)'; -% % 2) intercepts -% stats1.all{2} = svmobj1.b0; -% stats2.all{2} = svmobj2.b0; -% -% stats1.all_descrip = {'1:weights 2:intercept 3:distance from hyperplane for test data1 4: dist for data2'}; -% stats2.all_descrip = {'1:weights 2:intercept 3:distance from hyperplane for test data2 4: dist for data1'}; -% -% % Cross-val loop starts here -% -% for i = 1:numel(teIdx) -% -% % Prepare the hot-warm and Rej-friend data in Spider format -% trainobj1 = data(double(data1.dat'),data1.Y); -% trainobj2 = data(double(data2.dat'),data2.Y); -% -% % Select training and test data - trainobj, testobj -% testobj1 = trainobj1; -% testobj2 = trainobj2; -% -% trainobj1.X = trainobj1.X(trIdx{i}, :); -% trainobj1.Y = trainobj1.Y(trIdx{i}, :); -% -% trainobj2.X = trainobj2.X(trIdx{i}, :); -% trainobj2.Y = trainobj2.Y(trIdx{i}, :); -% -% testobj1.X = testobj1.X(teIdx{i}, :); -% testobj1.Y = testobj1.Y(teIdx{i}, :); -% -% testobj2.X = testobj2.X(teIdx{i}, :); -% testobj2.Y = testobj2.Y(teIdx{i}, :); -% -% -% %------------------------------------------------------------------- -% % Train on data1 and test on data1 and data2 -% -% % Set up an SVM object -% svmobj1 = svm({'optimizer="andre"','C=1','child=kernel'}); -% -% if dobalanced -% svmobj1.balanced_ridge = balanced_ridge; -% end -% -% % train on hw data set -% [~, svmobj1] = train(svmobj1, trainobj1); -% -% % Test on both HW and RF test set -% -% testobj11 = test(svmobj1, testobj1); -% testobj12 = test(svmobj1, testobj2); -% -% w = get_w(svmobj1)'; -% b0 = svmobj1.b0; -% -% dist11 = testobj1.X * w + b0; -% dist12 = testobj2.X * w + b0; -% -% if ~isequal((dist11>0)*2-1,testobj11.X) -% error('something is wrong'); -% elseif ~isequal((dist12>0)*2-1,testobj12.X) -% error('something is wrong'); -% end -% -% % Save predictions (cross-val) -% predicted11(teIdx{i}) = testobj11.X; -% predicted12(teIdx{i}) = testobj12.X; -% -% -% %------------------------------------------------------------------- -% % Train on RF and test on HW and RF -% -% % Set up an SVM object -% svmobj2 = svm({'optimizer="andre"','C=1','child=kernel'}); -% -% if dobalanced -% svmobj2.balanced_ridge = balanced_ridge; -% end -% % train on rf data set -% [~, svmobj2] = train(svmobj2, trainobj2); -% -% % Test on both HW and RF test set -% -% testobj21 = test(svmobj2, testobj1); -% testobj22 = test(svmobj2, testobj2); -% -% w = get_w(svmobj2)'; -% b0 = svmobj2.b0; -% -% dist21 = testobj1.X * w + b0; -% dist22 = testobj2.X * w + b0; -% -% if ~isequal((dist21>0)*2-1,testobj21.X) -% error('something is wrong'); -% elseif ~isequal((dist22>0)*2-1,testobj22.X) -% error('something is wrong'); -% end -% -% % Save predictions (cross-val) -% predicted21(teIdx{i}) = testobj21.X; -% predicted22(teIdx{i}) = testobj22.X; -% -% % get stats from svmobj -% % 1) weights -% stats1.cvoutput{i,1} = get_w(svmobj1)'; -% stats2.cvoutput{i,1} = get_w(svmobj2)'; -% % 2) intercepts -% stats1.cvoutput{i,2} = svmobj1.b0; -% stats2.cvoutput{i,2} = svmobj2.b0; -% % 3) distance from hyperplane for test data 1 and 2 -% stats1.cvoutput{i,3} = testobj1.X * stats1.cvoutput{i,1} + stats1.cvoutput{i,2}; -% stats1.cvoutput{i,4} = testobj2.X * stats1.cvoutput{i,1} + stats1.cvoutput{i,2}; -% -% stats2.cvoutput{i,3} = testobj1.X * stats2.cvoutput{i,1} + stats2.cvoutput{i,2}; -% stats2.cvoutput{i,4} = testobj2.X * stats2.cvoutput{i,1} + stats2.cvoutput{i,2}; -% -% stats1.cvoutput_descrip = {'1:weights 2:intercept 3:distance from hyperplane for test data1 4: dist for data2'}; -% stats2.cvoutput_descrip = {'1:weights 2:intercept 3:distance from hyperplane for test data2 4: dist for data1'}; -% -% dist_from_hyper11(teIdx{i}) = stats1.cvoutput{i,3}; -% dist_from_hyper12(teIdx{i}) = stats1.cvoutput{i,4}; -% dist_from_hyper21(teIdx{i}) = stats2.cvoutput{i,3}; -% dist_from_hyper22(teIdx{i}) = stats2.cvoutput{i,4}; -% end -% -% stats1.yfit = predicted11; -% stats2.yfit = predicted22; -% stats1.testfit = predicted12; -% stats2.testfit = predicted21; -% -% stats1.all{3} = dist_from_hyper11; -% stats1.all{4} = dist_from_hyper12; -% stats2.all{3} = dist_from_hyper22; -% stats2.all{4} = dist_from_hyper21; -% -% % RO on H vs. W -% % res_temp = binotest(predicted11(1:(length(predicted11)/2)) == -% %data1.Y(1:(length(predicted11)/2)), 0.5); -% % test_results.accuracy(1,1) = res_temp.prop; -% % test_results.accuracy_se(1,1) = res_temp.SE; -% % test_results.accuracy_p(1,1) = res_temp.p_val; -% % -% % res_temp = binotest(predicted12(1:(length(predicted12)/2)) == -% %data2.Y(1:(length(predicted12)/2)), 0.5); -% % test_results.accuracy(1,2) = res_temp.prop; -% % test_results.accuracy_se(1,2) = res_temp.SE; -% % test_results.accuracy_p(1,2) = res_temp.p_val; -% % -% % res_temp = binotest(predicted21(1:(length(predicted21)/2)) == -% %data1.Y(1:(length(predicted21)/2)), 0.5); -% % test_results.accuracy(1,3) = res_temp.prop; -% % test_results.accuracy_se(1,3) = res_temp.SE; -% % test_results.accuracy_p(1,3) = res_temp.p_val; -% % -% % res_temp = binotest(predicted22(1:(length(predicted22)/2)) == -% %data2.Y(1:(length(predicted22)/2)), 0.5); -% % test_results.accuracy(1,4) = res_temp.prop; -% % test_results.accuracy_se(1,4) = res_temp.SE; -% % test_results.accuracy_p(1,4) = res_temp.p_val; -% -% subjn = numel(teIdx); -% outcome = [true(subjn,1); false(subjn,1)]; -% -% hvsw_dist11 = stats1.all{3}(1:(subjn*2)); -% rvsf_dist12 = stats1.all{4}(1:(subjn*2)); -% hvsw_dist21 = stats2.all{4}(1:(subjn*2)); -% rvsf_dist22 = stats2.all{3}(1:(subjn*2)); -% -% -% % ROC = roc_plot(hvsw_dist11, outcome, 'threshold', 0); -% % test_results.accuracy_thresh0(1) = ROC.accuracy; -% % test_results.accuracy_thresh0_se(1) = ROC.accuracy_se; -% % test_results.accuracy_thresh0_p(1) = ROC.accuracy_p; -% -% try -% ROC = roc_plot(hvsw_dist11, outcome, 'twochoice'); -% acc(1) = ROC.accuracy; -% se(1) = ROC.accuracy_se; -% p(1) = ROC.accuracy_p; -% catch err -% % ROC = roc_plot(hvsw_dist11, outcome, 'twochoice'); -% acc(1) = NaN; -% se(1) = NaN; -% p(1) = NaN; -% end -% -% % -% % ROC = roc_plot_wani(rvsf_dist12, outcome, 'threshold', 0); -% % test_results.accuracy_thresh0(2) = ROC.accuracy; -% % test_results.accuracy_thresh0_se(2) = ROC.accuracy_se; -% % test_results.accuracy_thresh0_p(2) = ROC.accuracy_p; -% % -% try -% ROC = roc_plot(rvsf_dist12, outcome, 'twochoice'); -% acc(2) = ROC.accuracy; -% se(2) = ROC.accuracy_se; -% p(2) = ROC.accuracy_p; -% catch err -% acc(2) = NaN; -% se(2) = NaN; -% p(2) = NaN; -% end -% % -% % -% % ROC = roc_plot(hvsw_dist21, outcome, 'threshold', 0); -% % test_results.accuracy_thresh0(3) = ROC.accuracy; -% % test_results.accuracy_thresh0_se(3) = ROC.accuracy_se; -% % test_results.accuracy_thresh0_p(3) = ROC.accuracy_p; -% % -% try -% ROC = roc_plot(hvsw_dist21, outcome, 'twochoice'); -% acc(3) = ROC.accuracy; -% se(3) = ROC.accuracy_se; -% p(3) = ROC.accuracy_p; -% catch err -% acc(3) = NaN; -% se(3) = NaN; -% p(3) = NaN; -% end -% -% % ROC = roc_plot(rvsf_dist22, outcome, 'threshold', 0); -% % test_results.accuracy_thresh0(4) = ROC.accuracy; -% % test_results.accuracy_thresh0_se(4) = ROC.accuracy_se; -% % test_results.accuracy_thresh0_p(4) = ROC.accuracy_p; -% try -% ROC = roc_plot(rvsf_dist22, outcome, 'twochoice'); -% acc(4) = ROC.accuracy; -% se(4) = ROC.accuracy_se; -% p(4) = ROC.accuracy_p; -% catch err -% acc(4) = NaN; -% se(4) = NaN; -% p(4) = NaN; -% end -% -% close; - -% end diff --git a/CanlabCore/Statistics_tools/canlab_pattern_similarity.m b/CanlabCore/Statistics_tools/canlab_pattern_similarity.m index ea0c0c6d..7345c608 100644 --- a/CanlabCore/Statistics_tools/canlab_pattern_similarity.m +++ b/CanlabCore/Statistics_tools/canlab_pattern_similarity.m @@ -63,6 +63,7 @@ % can use 'treat_zero_as_data', 1 to treat the zeros as data values. % % **exclude_zero_mask_values** +% Excludes zero values in pattern_weights input % % :Outputs: % **similarity_output** @@ -139,6 +140,8 @@ % - added option for treating zero value in the map as real value rather % than missing data % +% 2026/01/13 Lukas Van Oudenhove +% - fixed a bug in main function and subfunction for sim_metric = corr % --------------------------------- % Defaults and optional inputs @@ -216,16 +219,6 @@ end -if exclude_zero_mask_values - - badvals_mask = pattern_weights == 0 | isnan(pattern_weights); - -else - - badvals_mask = isnan(pattern_weights); - -end - % --------------------------------- % Main similarity calculation @@ -245,6 +238,16 @@ switch sim_metric case 'corr' + + if exclude_zero_mask_values + + badvals_mask = pattern_weights(:,i) == 0 | isnan(pattern_weights(:,i)); + + else + + badvals_mask = isnan(pattern_weights(:,i)); + + end similarity_output(:, i) = image_correlation(dat, pattern_weights(:, i), badvals, badvals_mask); diff --git a/CanlabCore/Statistics_tools/searchlight_disti_Lukas.m b/CanlabCore/Statistics_tools/searchlight_disti_Lukas.m new file mode 100644 index 00000000..3895e7f6 --- /dev/null +++ b/CanlabCore/Statistics_tools/searchlight_disti_Lukas.m @@ -0,0 +1,330 @@ +function out = searchlight_disti_Lukas(dat, dist_i, additional_inputs) +%% ======================================================================== +% Accelerated, Parfor-Safe Group-Level Searchlight Decoding +% - Index-based sphere caching (RAM-safe) +% - Subject-wise permutation inference (paired A/B flips) +% - Optional fixed C (from whole-brain nested CV) +% - fmri_data reconstruction fully corrected (no remove_empty errors) +% - Right-tailed voxelwise permutation p-values (acc & AUC) +% +% IMPORTANT: +% dat MUST ALREADY be masked and trimmed before calling this function. +% ======================================================================== + +%% ------------------------------------------------------------------------ +% Parse optional inputs +% ------------------------------------------------------------------------- +do_cross = false; +r = 3; +P = 500; +fixedC = []; % NEW: optional user-supplied C + +for k = 1:length(additional_inputs) + if ischar(additional_inputs{k}) + switch additional_inputs{k} + + case 'dat2' + do_cross = true; + dat2 = additional_inputs{k+1}; + additional_inputs{k}=[]; additional_inputs{k+1}=[]; + + case 'algorithm_name' + algorithm_name = additional_inputs{k+1}; + additional_inputs{k}=[]; additional_inputs{k+1}=[]; + + case 'r' + r = additional_inputs{k+1}; + additional_inputs{k}=[]; additional_inputs{k+1}=[]; + + case 'cv_assign' + cv_assign = additional_inputs{k+1}; + additional_inputs{k}=[]; additional_inputs{k+1}=[]; + + case 'n_permutations' + P = additional_inputs{k+1}; + additional_inputs{k}=[]; additional_inputs{k+1}=[]; + + case 'fixedC' % NEW: fixed SVM C value + fixedC = additional_inputs{k+1}; + additional_inputs{k}=[]; additional_inputs{k+1}=[]; + end + end +end + +%% ------------------------------------------------------------------------ +% Convert to single precision (faster) +% ------------------------------------------------------------------------- +dat.dat = single(dat.dat); +dat.volInfo.xyzlist = single(dat.volInfo.xyzlist); + +if do_cross + dat2.dat = single(dat2.dat); + dat2.volInfo.xyzlist = single(dat2.volInfo.xyzlist); +end + +%% ------------------------------------------------------------------------ +% dist_i corresponds to ALREADY masked+trimmed dat +% ------------------------------------------------------------------------- +vox_to_run = find(dist_i(:))'; +N = numel(vox_to_run); +nvox = dat.volInfo.n_inmask; + +%% ------------------------------------------------------------------------ +% Parfor constants +% ------------------------------------------------------------------------- +Cdat = parallel.pool.Constant(dat); +Calg = parallel.pool.Constant(algorithm_name); +Ccv = parallel.pool.Constant(cv_assign); +CfixedC = parallel.pool.Constant(fixedC); % NEW + +if do_cross + Cdat2 = parallel.pool.Constant(dat2); +else + emptydat = dat; + emptydat.dat = []; + emptydat.removed_voxels = []; + Cdat2 = parallel.pool.Constant(emptydat); +end + +xyz = dat.volInfo.xyzlist; + +%% ======================================================================== +% SPHERE CACHING (index-based) +% ======================================================================== +fprintf('Caching %d spheres (radius=%d)...\n', N, r); + +sphere_cache = cell(N,1); +for ii = 1:N + c = vox_to_run(ii); + dx = xyz(:,1) - xyz(c,1); + dy = xyz(:,2) - xyz(c,2); + dz = xyz(:,3) - xyz(c,3); + sphere_cache{ii} = int32(find((dx.*dx + dy.*dy + dz.*dz) <= r*r)); +end +Cspheres = parallel.pool.Constant(sphere_cache); + +%% ======================================================================== +% TEMPLATE fmri_data OBJECT (prevents huge copying inside parfor) +% ======================================================================== +template = dat; +template.dat = []; +template.removed_voxels = false(size(dat.removed_voxels)); +template.volInfo.xyzlist = []; +template.volInfo.wh_inmask = []; +template.volInfo.n_inmask = []; +Ctemplate = parallel.pool.Constant(template); + +%% ======================================================================== +% STEP 1 — REAL SEARCHLIGHT DECODING +% ======================================================================== +fprintf('Computing REAL searchlight (%d voxels)...\n', N); +q_real = makeProgressTracker(N); + +real_acc = nan(N,1,'single'); +real_auc = nan(N,1,'single'); +real_se = nan(N,1,'single'); +real_r = nan(N,1,'single'); + +parfor idx = 1:N + + base = Cdat.Value; + keep = Cspheres.Value{idx}; + + % --- FAST sphere-specific fmri_data object --- + dat_local = Ctemplate.Value; + dat_local.dat = base.dat(keep,:); + rv = true(size(base.removed_voxels)); + rv(keep) = false; + dat_local.removed_voxels = rv; + + dat_local.volInfo.xyzlist = base.volInfo.xyzlist(keep,:); + dat_local.volInfo.wh_inmask = find(~rv); + dat_local.volInfo.n_inmask = numel(dat_local.volInfo.wh_inmask); + + R = predict_center_fast(dat_local, Calg.Value, Ccv.Value, CfixedC.Value); + + real_acc(idx) = R.acc; + real_auc(idx) = R.AUC; + real_se(idx) = R.se; + real_r(idx) = R.r; + + send(q_real,1); +end + +%% ======================================================================== +% STEP 2 — SUBJECT-WISE PERMUTATIONS +% ======================================================================== +fprintf('Running %d permutations...\n', P); +q_perm = makeProgressTracker(P); + +Nsub = length(Cdat.Value.Y)/2; + +perm_count_acc = zeros(N,1,'uint32'); +perm_count_auc = zeros(N,1,'uint32'); + +parfor p = 1:P + + base = Cdat.Value; + dat_perm = base; + + % --- subject-wise paired sign-flips --- + flipvec = rand(Nsub,1) > 0.5; + for s = 1:Nsub + a = s; b = s + Nsub; + if flipvec(s) + dat_perm.Y(a) = -1; + dat_perm.Y(b) = 1; + tmp = dat_perm.dat(a,:); + dat_perm.dat(a,:) = dat_perm.dat(b,:); + dat_perm.dat(b,:) = tmp; + else + dat_perm.Y(a) = 1; + dat_perm.Y(b) = -1; + end + end + + local_acc = zeros(N,1,'uint32'); + local_auc = zeros(N,1,'uint32'); + + for idx = 1:N + + keep = Cspheres.Value{idx}; + + dat_local = Ctemplate.Value; + dat_local.dat = dat_perm.dat(keep,:); + rv = true(size(dat_perm.removed_voxels)); + rv(keep) = false; + dat_local.removed_voxels = rv; + + dat_local.volInfo.xyzlist = base.volInfo.xyzlist(keep,:); + dat_local.volInfo.wh_inmask = find(~rv); + dat_local.volInfo.n_inmask = numel(dat_local.volInfo.wh_inmask); + + R = predict_center_fast(dat_local, Calg.Value, Ccv.Value, CfixedC.Value); + + if R.acc >= real_acc(idx), local_acc(idx) = 1; end + if R.AUC >= real_auc(idx), local_auc(idx) = 1; end + end + + perm_count_acc = perm_count_acc + local_acc; + perm_count_auc = perm_count_auc + local_auc; + + send(q_perm,1); +end + +p_acc = (perm_count_acc + 1)/(P+1); +p_auc = (perm_count_auc + 1)/(P+1); + +%% ======================================================================== +% STEP 3 — BUILD OUTPUT +% ======================================================================== +out = struct(); +out.test_results = cell(1,1); + +full_acc = nan(nvox,1,'single'); +full_auc = nan(nvox,1,'single'); +full_se = nan(nvox,1,'single'); +full_r = nan(nvox,1,'single'); +full_pacc = nan(nvox,1,'single'); +full_pauc = nan(nvox,1,'single'); + +full_acc(vox_to_run) = real_acc; +full_auc(vox_to_run) = real_auc; +full_se(vox_to_run) = real_se; +full_r(vox_to_run) = real_r; +full_pacc(vox_to_run) = p_acc; +full_pauc(vox_to_run) = p_auc; + +out.test_results{1}.acc = full_acc; +out.test_results{1}.AUC = full_auc; +out.test_results{1}.se = full_se; +out.test_results{1}.r = full_r; +out.test_results{1}.p_acc_perm = full_pacc; +out.test_results{1}.p_auc_perm = full_pauc; + +end + +%% ======================================================================== +% HELPER FUNCTIONS +% ======================================================================== + +function R = predict_center_fast(dat_local, alg, cv, fixedC) + [test_Y, dat_local] = setup_testvar(dat_local); + + if ~isempty(fixedC) + [~, stats] = predict(dat_local, ... + 'algorithm_name', alg, ... + 'C', fixedC, ... + 'nfolds', cv, ... + 'verbose', 0); + else + [~, stats] = predict(dat_local, ... + 'algorithm_name', alg, ... + 'nfolds', cv, ... + 'verbose', 0); + end + + R = get_test_results(stats, test_Y, alg); +end + +function R = get_test_results(stats, test_Y, alg) + Y = test_Y{1}; + valid = (Y ~= 0); + ytrue = (Y(valid)==1); + + if contains(alg,'svm') + scores = stats.dist_from_hyperplane_xval(valid); + else + scores = stats.yfit(valid); + end + + acc = mean((scores>0)==ytrue); + + try + [~,~,~,AUC] = perfcurve(ytrue, scores, true); + catch + AUC = NaN; + end + + try + pos = scores(ytrue==1); + neg = scores(ytrue==0); + n1 = numel(pos); n0 = numel(neg); + Q1 = AUC/(2-AUC); + Q2 = 2*AUC*AUC/(1+AUC); + varAUC = (AUC*(1-AUC)+(n1-1)*(Q1-AUC^2)+(n0-1)*(Q2-AUC^2))/(n1*n0); + seAUC = sqrt(max(varAUC,0)); + catch + seAUC = NaN; + end + + R.acc = acc; + R.AUC = AUC; + R.se = seAUC; + R.r = NaN; +end + +function [test_Y, dat] = setup_testvar(dat) + test_Y{1} = dat.Y; + dat.Y = dat.Y(:,1); +end + +function q = makeProgressTracker(totalCount) + progress.total = totalCount; + progress.current = 0; + progress.lastUpdate = tic; + + q = parallel.pool.DataQueue; + afterEach(q,@(~)update()); + + function update() + progress.current = progress.current + 1; + if toc(progress.lastUpdate) > 0.1 || progress.current == progress.total + fprintf('\rProgress: %d/%d (%.1f%%)', ... + progress.current, progress.total, ... + 100*progress.current/progress.total); + progress.lastUpdate = tic; + if progress.current == progress.total, fprintf('\n'); end + end + end +end \ No newline at end of file From 09524bf6fd0578e5f9439a7b578b91a2b223b071 Mon Sep 17 00:00:00 2001 From: lukasvo76 Date: Fri, 19 Jun 2026 13:19:43 +0200 Subject: [PATCH 2/2] finished searchlight_disti_Lukas function --- .../searchlight_disti_Lukas.m | 506 +++++++++--------- 1 file changed, 259 insertions(+), 247 deletions(-) diff --git a/CanlabCore/Statistics_tools/searchlight_disti_Lukas.m b/CanlabCore/Statistics_tools/searchlight_disti_Lukas.m index 3895e7f6..6181db9c 100644 --- a/CanlabCore/Statistics_tools/searchlight_disti_Lukas.m +++ b/CanlabCore/Statistics_tools/searchlight_disti_Lukas.m @@ -1,330 +1,342 @@ function out = searchlight_disti_Lukas(dat, dist_i, additional_inputs) %% ======================================================================== -% Accelerated, Parfor-Safe Group-Level Searchlight Decoding -% - Index-based sphere caching (RAM-safe) -% - Subject-wise permutation inference (paired A/B flips) -% - Optional fixed C (from whole-brain nested CV) -% - fmri_data reconstruction fully corrected (no remove_empty errors) -% - Right-tailed voxelwise permutation p-values (acc & AUC) +% searchlight_disti_Lukas_tfce % -% IMPORTANT: -% dat MUST ALREADY be masked and trimmed before calling this function. +% Group-level searchlight decoding with TFCE-only permutation inference +% using a LOCAL implementation of classic TFCE (Smith & Nichols, 2009). +% +% No dependency on SPM / TDT / CoSMoMVPA TFCE implementations. +% +% ASSUMPTIONS: +% - dat is already masked & trimmed +% - exactly 2 images per subject: [A1..AN B1..BN] +% - dat.Y = [1..1 -1..-1] +% +% OUTPUT: +% out.real_auc +% out.real_t +% out.TFCE_real +% out.TFCE_real_max +% out.TFCE_null_max +% out.p_TFCE_global +% out.p_TFCE_voxel +% out.tfce_stat_image % ======================================================================== -%% ------------------------------------------------------------------------ -% Parse optional inputs -% ------------------------------------------------------------------------- -do_cross = false; -r = 3; -P = 500; -fixedC = []; % NEW: optional user-supplied C +%% -------------------- Defaults -------------------- +r = 3; +P = 100; +alg = 'cv_svm'; +fixedC = []; +cv_assign = []; +% TFCE defaults (Smith & Nichols) +tfce_H = 2; +tfce_E = 0.5; +tfce_conn = 26; + +%% -------------------- Parse inputs -------------------- for k = 1:length(additional_inputs) if ischar(additional_inputs{k}) - switch additional_inputs{k} - - case 'dat2' - do_cross = true; - dat2 = additional_inputs{k+1}; - additional_inputs{k}=[]; additional_inputs{k+1}=[]; - - case 'algorithm_name' - algorithm_name = additional_inputs{k+1}; - additional_inputs{k}=[]; additional_inputs{k+1}=[]; - + switch lower(additional_inputs{k}) case 'r' r = additional_inputs{k+1}; - additional_inputs{k}=[]; additional_inputs{k+1}=[]; - + case 'algorithm_name' + alg = additional_inputs{k+1}; + case 'fixedc' + fixedC = additional_inputs{k+1}; case 'cv_assign' cv_assign = additional_inputs{k+1}; - additional_inputs{k}=[]; additional_inputs{k+1}=[]; - case 'n_permutations' P = additional_inputs{k+1}; - additional_inputs{k}=[]; additional_inputs{k+1}=[]; - - case 'fixedC' % NEW: fixed SVM C value - fixedC = additional_inputs{k+1}; - additional_inputs{k}=[]; additional_inputs{k+1}=[]; + case 'tfce_h' + tfce_H = additional_inputs{k+1}; + case 'tfce_e' + tfce_E = additional_inputs{k+1}; + case 'tfce_conn' + tfce_conn = additional_inputs{k+1}; end end end -%% ------------------------------------------------------------------------ -% Convert to single precision (faster) -% ------------------------------------------------------------------------- -dat.dat = single(dat.dat); -dat.volInfo.xyzlist = single(dat.volInfo.xyzlist); - -if do_cross - dat2.dat = single(dat2.dat); - dat2.volInfo.xyzlist = single(dat2.volInfo.xyzlist); -end - -%% ------------------------------------------------------------------------ -% dist_i corresponds to ALREADY masked+trimmed dat -% ------------------------------------------------------------------------- -vox_to_run = find(dist_i(:))'; -N = numel(vox_to_run); -nvox = dat.volInfo.n_inmask; - -%% ------------------------------------------------------------------------ -% Parfor constants -% ------------------------------------------------------------------------- -Cdat = parallel.pool.Constant(dat); -Calg = parallel.pool.Constant(algorithm_name); -Ccv = parallel.pool.Constant(cv_assign); -CfixedC = parallel.pool.Constant(fixedC); % NEW - -if do_cross - Cdat2 = parallel.pool.Constant(dat2); +if ~isempty(cv_assign) + Ccv = parallel.pool.Constant(cv_assign); else - emptydat = dat; - emptydat.dat = []; - emptydat.removed_voxels = []; - Cdat2 = parallel.pool.Constant(emptydat); + Ccv = []; end -xyz = dat.volInfo.xyzlist; +%% -------------------- Dimensions -------------------- +Y = dat.Y; +Nsub = numel(Y) / 2; +xyz = single(dat.volInfo.xyzlist); -%% ======================================================================== -% SPHERE CACHING (index-based) -% ======================================================================== -fprintf('Caching %d spheres (radius=%d)...\n', N, r); +vox_to_run = find(dist_i(:)); +Nvox = numel(vox_to_run); + +fprintf('Searchlight voxels: %d\n', Nvox); +fprintf('Subjects: %d\n', Nsub); +fprintf('TFCE permutations: %d\n', P); -sphere_cache = cell(N,1); -for ii = 1:N - c = vox_to_run(ii); +%% -------------------- Sphere caching -------------------- +fprintf('Caching spheres (r=%d)...\n', r); +sphere_cache = cell(Nvox,1); +for v = 1:Nvox + c = vox_to_run(v); dx = xyz(:,1) - xyz(c,1); dy = xyz(:,2) - xyz(c,2); dz = xyz(:,3) - xyz(c,3); - sphere_cache{ii} = int32(find((dx.*dx + dy.*dy + dz.*dz) <= r*r)); + sphere_cache{v} = int32(find(dx.^2 + dy.^2 + dz.^2 <= r^2)); end Cspheres = parallel.pool.Constant(sphere_cache); -%% ======================================================================== -% TEMPLATE fmri_data OBJECT (prevents huge copying inside parfor) -% ======================================================================== +%% -------------------- fmri_data template -------------------- template = dat; template.dat = []; template.removed_voxels = false(size(dat.removed_voxels)); -template.volInfo.xyzlist = []; -template.volInfo.wh_inmask = []; -template.volInfo.n_inmask = []; Ctemplate = parallel.pool.Constant(template); %% ======================================================================== -% STEP 1 — REAL SEARCHLIGHT DECODING -% ======================================================================== -fprintf('Computing REAL searchlight (%d voxels)...\n', N); -q_real = makeProgressTracker(N); +% STEP 1 — REAL SEARCHLIGHT +%% ======================================================================== +fprintf('Running REAL searchlight decoding...\n'); -real_acc = nan(N,1,'single'); -real_auc = nan(N,1,'single'); -real_se = nan(N,1,'single'); -real_r = nan(N,1,'single'); +tracker_real = ProgressTracker(Nvox); +q_real = parallel.pool.DataQueue; +afterEach(q_real,@(~)tracker_real.update()); -parfor idx = 1:N +real_auc = nan(Nvox,1,'single'); - base = Cdat.Value; - keep = Cspheres.Value{idx}; +parfor v = 1:Nvox + keep = Cspheres.Value{v}; + base = dat; - % --- FAST sphere-specific fmri_data object --- - dat_local = Ctemplate.Value; - dat_local.dat = base.dat(keep,:); + dloc = Ctemplate.Value; + dloc.dat = base.dat(keep,:); rv = true(size(base.removed_voxels)); - rv(keep) = false; - dat_local.removed_voxels = rv; + rv(keep)=false; + dloc.removed_voxels = rv; - dat_local.volInfo.xyzlist = base.volInfo.xyzlist(keep,:); - dat_local.volInfo.wh_inmask = find(~rv); - dat_local.volInfo.n_inmask = numel(dat_local.volInfo.wh_inmask); - - R = predict_center_fast(dat_local, Calg.Value, Ccv.Value, CfixedC.Value); + if isempty(fixedC) + if isempty(Ccv) + [~,stats] = predict(dloc,'algorithm_name',alg,'verbose',0); + else + [~,stats] = predict(dloc,'algorithm_name',alg,'cv_assign',Ccv.Value,'verbose',0); + end + else + if isempty(Ccv) + [~,stats] = predict(dloc,'algorithm_name',alg,'C',fixedC,'verbose',0); + else + [~,stats] = predict(dloc,'algorithm_name',alg,'C',fixedC,'cv_assign',Ccv.Value,'verbose',0); + end + end - real_acc(idx) = R.acc; - real_auc(idx) = R.AUC; - real_se(idx) = R.se; - real_r(idx) = R.r; + scores = stats.dist_from_hyperplane_xval; + ytrue = (Y==1); + [~,~,~,AUC] = perfcurve(ytrue,scores,true); + real_auc(v)=AUC; send(q_real,1); end -%% ======================================================================== -% STEP 2 — SUBJECT-WISE PERMUTATIONS -% ======================================================================== -fprintf('Running %d permutations...\n', P); -q_perm = makeProgressTracker(P); +%% -------------------- AUC -> t -------------------- +real_t = auc_to_t(real_auc,Nsub,Nsub); -Nsub = length(Cdat.Value.Y)/2; +%% -------------------- TFCE on real map -------------------- +vol_dim = dat.volInfo.dim; +tvol = zeros(vol_dim,'single'); +tvol(~dat.removed_voxels)=real_t; -perm_count_acc = zeros(N,1,'uint32'); -perm_count_auc = zeros(N,1,'uint32'); +TFCE_real_vol = tfce_transform_3d(tvol,tfce_H,tfce_E,tfce_conn); +TFCE_real = TFCE_real_vol(~dat.removed_voxels); +TFCE_real_max = max(TFCE_real); -parfor p = 1:P +%% ======================================================================== +% STEP 2 — TFCE-ONLY PERMUTATIONS +%% ======================================================================== +fprintf('Running TFCE-only permutations...\n'); - base = Cdat.Value; - dat_perm = base; +tracker_perm = ProgressTracker(P); +q_perm = parallel.pool.DataQueue; +afterEach(q_perm,@(~)tracker_perm.update()); - % --- subject-wise paired sign-flips --- - flipvec = rand(Nsub,1) > 0.5; - for s = 1:Nsub - a = s; b = s + Nsub; - if flipvec(s) - dat_perm.Y(a) = -1; - dat_perm.Y(b) = 1; - tmp = dat_perm.dat(a,:); - dat_perm.dat(a,:) = dat_perm.dat(b,:); - dat_perm.dat(b,:) = tmp; - else - dat_perm.Y(a) = 1; - dat_perm.Y(b) = -1; - end +maxTFCE_null = nan(P,1,'single'); + +parfor p=1:P + perm_dat = dat; + flip = rand(Nsub,1)>0.5; + for s=1:Nsub + a=s; b=s+Nsub; + if flip(s), perm_dat.Y([a b])=perm_dat.Y([b a]); end end - local_acc = zeros(N,1,'uint32'); - local_auc = zeros(N,1,'uint32'); + perm_auc = nan(Nvox,1,'single'); + for v=1:Nvox + keep = Cspheres.Value{v}; + dloc = Ctemplate.Value; + dloc.dat = perm_dat.dat(keep,:); + perm_auc(v)=get_auc(Ctemplate.Value,perm_dat,keep,alg,fixedC,Ccv); + end - for idx = 1:N + t_perm = auc_to_t(perm_auc,Nsub,Nsub); + tvol_perm = zeros(vol_dim,'single'); + tvol_perm(~dat.removed_voxels)=t_perm; - keep = Cspheres.Value{idx}; + TFCE_perm = tfce_transform_3d(tvol_perm,tfce_H,tfce_E,tfce_conn); + maxTFCE_null(p)=max(TFCE_perm(~dat.removed_voxels)); + send(q_perm,1); +end + +%% -------------------- TFCE p-maps -------------------- +p_TFCE_global = (sum(maxTFCE_null>=TFCE_real_max)+1)/(P+1); - dat_local = Ctemplate.Value; - dat_local.dat = dat_perm.dat(keep,:); - rv = true(size(dat_perm.removed_voxels)); - rv(keep) = false; - dat_local.removed_voxels = rv; +TFCE_p_voxel = nan(size(TFCE_real),'single'); +for v=1:numel(TFCE_real) + TFCE_p_voxel(v)=(sum(maxTFCE_null>=TFCE_real(v))+1)/(P+1); +end - dat_local.volInfo.xyzlist = base.volInfo.xyzlist(keep,:); - dat_local.volInfo.wh_inmask = find(~rv); - dat_local.volInfo.n_inmask = numel(dat_local.volInfo.wh_inmask); +tfce_stat_img = statistic_image('type','p'); +tfce_stat_img.volInfo = dat.volInfo; +tfce_stat_img.removed_voxels = dat.removed_voxels; +tfce_stat_img.dat = TFCE_p_voxel; +tfce_stat_img.dat_descrip = ['uncorrected TFCE p-values based on ' num2str(P) ' permutations']; +tfce_stat_img.p = TFCE_p_voxel; +tfce_stat_img.p_type = ['uncorrected TFCE p-values based on ' num2str(P) ' permutations']; - R = predict_center_fast(dat_local, Calg.Value, Ccv.Value, CfixedC.Value); +% Apply Storey's FDR +[~, TFCE_q_voxel_fdr, aprioriprob] = mafdr(TFCE_p_voxel); - if R.acc >= real_acc(idx), local_acc(idx) = 1; end - if R.AUC >= real_auc(idx), local_auc(idx) = 1; end - end +% If aprioriprob > 0.99, fallback to Benjamini–Hochberg +if aprioriprob > 0.99 + TFCE_p_voxel_fdr = mafdr(TFCE_p_voxel, 'BHFDR', true); +else +% Enforce constraint q >= p (as in SAS proc multtest) + for j = 1:length(TFCE_q_voxel_fdr) + if TFCE_q_voxel_fdr(j) < p_unc(j) + TFCE_q_voxel_fdr(j) = p_unc(j); + end + end - perm_count_acc = perm_count_acc + local_acc; - perm_count_auc = perm_count_auc + local_auc; +end - send(q_perm,1); +tfce_stat_img_fdr = tfce_stat_img; +if exist('TFCE_p_voxel','var') + tfce_stat_img_fdr.dat = TFCE_p_voxel_fdr; + tfce_stat_img_fdr.dat_descrip = ['fdr corrected TFCE p-values based on ' num2str(P) ' permutations']; + tfce_stat_img_fdr.p = TFCE_p_voxel_fdr; + tfce_stat_img_fdr.p_type = ['fdr corrected TFCE p-values based on ' num2str(P) ' permutations']; +else + tfce_stat_img_fdr.dat = TFCE_q_voxel_fdr; + tfce_stat_img_fdr.dat_descrip = ['fdr corrected TFCE q-values based on ' num2str(P) ' permutations']; + tfce_stat_img_fdr.p = TFCE_q_voxel_fdr; + tfce_stat_img_fdr.p_type = ['fdr corrected TFCE q-values based on ' num2str(P) ' permutations']; end -p_acc = (perm_count_acc + 1)/(P+1); -p_auc = (perm_count_auc + 1)/(P+1); +%% -------------------- Output -------------------- +out.real_auc = real_auc; +out.real_t = real_t; +out.TFCE_real = TFCE_real; +out.TFCE_real_max = TFCE_real_max; +out.TFCE_null_max = maxTFCE_null; +out.p_TFCE_global = p_TFCE_global; +out.p_TFCE_voxel = TFCE_p_voxel; +out.tfce_stat_image = tfce_stat_img; +out.tfce_stat_image_fdr = tfce_stat_img_fdr; + +fprintf('Global TFCE FWE p = %.4f\n',p_TFCE_global); +end %% ======================================================================== -% STEP 3 — BUILD OUTPUT -% ======================================================================== -out = struct(); -out.test_results = cell(1,1); - -full_acc = nan(nvox,1,'single'); -full_auc = nan(nvox,1,'single'); -full_se = nan(nvox,1,'single'); -full_r = nan(nvox,1,'single'); -full_pacc = nan(nvox,1,'single'); -full_pauc = nan(nvox,1,'single'); - -full_acc(vox_to_run) = real_acc; -full_auc(vox_to_run) = real_auc; -full_se(vox_to_run) = real_se; -full_r(vox_to_run) = real_r; -full_pacc(vox_to_run) = p_acc; -full_pauc(vox_to_run) = p_auc; - -out.test_results{1}.acc = full_acc; -out.test_results{1}.AUC = full_auc; -out.test_results{1}.se = full_se; -out.test_results{1}.r = full_r; -out.test_results{1}.p_acc_perm = full_pacc; -out.test_results{1}.p_auc_perm = full_pauc; - +% Helper: Classic TFCE (Smith & Nichols, 2009) +%% ======================================================================== +function tfce = tfce_transform_3d(stat,H,E,conn) + +if nargin<4, conn=26; end +tfce = zeros(size(stat),'single'); + +hvals = unique(stat(:)); +hvals = hvals(hvals>0); +dh = diff([0; hvals]); + +for i=1:numel(hvals) + thr = hvals(i); + bw = stat>=thr; + CC = bwconncomp(bw,conn); + for c=1:CC.NumObjects + idx = CC.PixelIdxList{c}; + tfce(idx)=tfce(idx)+(thr^H)*(numel(idx)^E)*dh(i); + end +end end %% ======================================================================== -% HELPER FUNCTIONS -% ======================================================================== +% Helper: AUC -> t (Hanley & McNeil, 1982) +%% ======================================================================== +function t = auc_to_t(auc,n_pos,n_neg) -function R = predict_center_fast(dat_local, alg, cv, fixedC) - [test_Y, dat_local] = setup_testvar(dat_local); +auc = max(min(auc,1-eps),eps); +Q1 = auc./(2-auc); +Q2 = (2*auc.^2)./(1+auc); - if ~isempty(fixedC) - [~, stats] = predict(dat_local, ... - 'algorithm_name', alg, ... - 'C', fixedC, ... - 'nfolds', cv, ... - 'verbose', 0); - else - [~, stats] = predict(dat_local, ... - 'algorithm_name', alg, ... - 'nfolds', cv, ... - 'verbose', 0); - end +varAUC = (auc.*(1-auc) + (n_pos-1).*(Q1-auc.^2) + ... + (n_neg-1).*(Q2-auc.^2))/(n_pos*n_neg); - R = get_test_results(stats, test_Y, alg); +SE = sqrt(max(varAUC,eps)); +t = (auc-0.5)./SE; end -function R = get_test_results(stats, test_Y, alg) - Y = test_Y{1}; - valid = (Y ~= 0); - ytrue = (Y(valid)==1); +function AUC = get_auc(dloc_template, dat_full, keep, alg, fixedC, Ccv) - if contains(alg,'svm') - scores = stats.dist_from_hyperplane_xval(valid); - else - scores = stats.yfit(valid); - end + % --- dat --- + dloc = dloc_template; % local copy + dloc.dat = dat_full.dat(keep,:); - acc = mean((scores>0)==ytrue); + % --- removed_voxels (LENGTH MUST EQUAL nvox) --- + rv = true(dat_full.volInfo.nvox,1); + rv(keep) = false; + dloc.removed_voxels = rv; - try - [~,~,~,AUC] = perfcurve(ytrue, scores, true); - catch - AUC = NaN; - end + % --- volInfo fields --- + % IMPORTANT: nvox must remain FULL MASK size + dloc.volInfo.nvox = dat_full.volInfo.nvox; - try - pos = scores(ytrue==1); - neg = scores(ytrue==0); - n1 = numel(pos); n0 = numel(neg); - Q1 = AUC/(2-AUC); - Q2 = 2*AUC*AUC/(1+AUC); - varAUC = (AUC*(1-AUC)+(n1-1)*(Q1-AUC^2)+(n0-1)*(Q2-AUC^2))/(n1*n0); - seAUC = sqrt(max(varAUC,0)); - catch - seAUC = NaN; - end + % In-mask voxel indices INTO FULL SPACE + dloc.volInfo.wh_inmask = find(~rv); - R.acc = acc; - R.AUC = AUC; - R.se = seAUC; - R.r = NaN; -end + % xyzlist is ONLY in-mask coordinates + dloc.volInfo.xyzlist = dat_full.volInfo.xyzlist(keep,:); -function [test_Y, dat] = setup_testvar(dat) - test_Y{1} = dat.Y; - dat.Y = dat.Y(:,1); -end + % n_inmask must match dat rows + dloc.volInfo.n_inmask = numel(keep); -function q = makeProgressTracker(totalCount) - progress.total = totalCount; - progress.current = 0; - progress.lastUpdate = tic; - - q = parallel.pool.DataQueue; - afterEach(q,@(~)update()); - - function update() - progress.current = progress.current + 1; - if toc(progress.lastUpdate) > 0.1 || progress.current == progress.total - fprintf('\rProgress: %d/%d (%.1f%%)', ... - progress.current, progress.total, ... - 100*progress.current/progress.total); - progress.lastUpdate = tic; - if progress.current == progress.total, fprintf('\n'); end + % --- predict --- + if isempty(fixedC) + if isempty(Ccv) + [~,stats] = predict(dloc, ... + 'algorithm_name',alg, ... + 'verbose',0); + else + [~,stats] = predict(dloc, ... + 'algorithm_name',alg, ... + 'cv_assign',Ccv.Value, ... + 'verbose',0); + end + else + if isempty(Ccv) + [~,stats] = predict(dloc, ... + 'algorithm_name',alg, ... + 'C',fixedC, ... + 'verbose',0); + else + [~,stats] = predict(dloc, ... + 'algorithm_name',alg, ... + 'C',fixedC, ... + 'cv_assign',Ccv.Value, ... + 'verbose',0); end end + + % --- AUC --- + scores = stats.dist_from_hyperplane_xval; + ytrue = (dat_full.Y == 1); + [~,~,~,AUC] = perfcurve(ytrue, scores, true); end \ No newline at end of file