Calculate leave-one-subject-out cross-validation accuracy for a linear SVM.

syntax: [ percCorrect w predictedVsTrue ] = runSVMLeave1Out( featureVect, classLabels, numSamplesPerSubj, expLabels )

Inputs:
  featureVect: (dim x numSamples) matrix of data
  classLabels: (1 x numSamples) class labels valued 0, 1 or 2. Originally, class 1 for learn, 0 for didn't learn, 2 for remove because unclear; samples labeled 2 are excluded from training and testing
  numSamplesPerSubj: number of samples per subject, used to split the data for leave 1 subject out cross validation
  expLabels: optional struct array (fields .train and .test holding column indices), usually returned by the 'getLeave1OutLabels' function, which specifies how to partition the data for leave 1 subject out cross validation

Outputs:
  percCorrect: fraction (0 to 1, not a percentage despite the name) of test samples correctly classified, pooled over all leave 1 subject out folds
  w: weight (normal) vector of the linear SVM trained on all the well-labeled data (labels 0 and 1)
  predictedVsTrue: cell containing the [predicted ; true] labels of the data for each fold of the cross validation

Requires the LIBSVM MATLAB interface (svmtrain/svmpredict) on the path.

Note, class 1 for learn, 0 for didn't learn, 2 for remove because unclear.
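The source of getLeave1OutLabels is not shown here; the listing only reveals that expLabels is indexed as expLabels(i1).train and expLabels(i1).test and used to select columns. The following is a hypothetical sketch of building folds in that form, assuming the columns are grouped subject by subject with exactly numSamplesPerSubj samples each. It is not the original helper.

```matlab
% Hypothetical sketch (NOT the original getLeave1OutLabels): build
% leave-one-subject-out folds as a struct array with fields .train and .test,
% the form runSVMLeave1Out indexes as expLabels(i1).train / expLabels(i1).test.
% Assumption: columns are grouped by subject, numSamplesPerSubj samples each.
numSamples = 6;
numSamplesPerSubj = 2;

numSubj = numSamples / numSamplesPerSubj;   % assumed to divide evenly
expLabels = struct('train', {}, 'test', {});
for s = 1:numSubj
    % columns belonging to subject s are held out as the test set
    testIdx = (s-1)*numSamplesPerSubj + (1:numSamplesPerSubj);
    expLabels(s).test  = testIdx;
    % every other subject's columns form the training set
    expLabels(s).train = setdiff(1:numSamples, testIdx);
end

disp(expLabels(2).test)   % fold 2 holds out columns 3 and 4
```

If subjects contribute different numbers of samples, the test indices would instead come from cumulative per-subject counts; which convention the real helper uses is not stated in the source.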
% Calculate leave-one-subject-out cross-validation accuracy for a linear SVM
%
% syntax: [ percCorrect w predictedVsTrue ] = runSVMLeave1Out( featureVect, classLabels, numSamplesPerSubj, expLabels )
%
% Inputs:
%   featureVect: (dim x numSamples) matrix of data
%   classLabels: (1 x numSamples) class labels valued 0, 1 or 2. Originally,
%       class 1 for learn, 0 for didn't learn, 2 for remove because unclear;
%       samples labeled 2 are excluded from training and testing
%   numSamplesPerSubj: number of samples per subject, used to split the data
%       for leave 1 subject out cross validation
%   expLabels: optional struct array (fields .train and .test holding column
%       indices), usually returned by the 'getLeave1OutLabels' function, which
%       specifies how to partition the data for leave 1 subject out cross
%       validation
%
% Outputs:
%   percCorrect: fraction (0 to 1) of test samples correctly classified,
%       pooled over all leave 1 subject out folds
%   w: weight (normal) vector of the linear SVM trained on all the
%       well-labeled data (labels 0 and 1)
%   predictedVsTrue: cell containing the [predicted ; true] labels of the
%       data for each fold of the cross validation
%
% Requires the LIBSVM MATLAB interface (svmtrain/svmpredict) on the path.
%
% Note, class 1 for learn, 0 for didn't learn, 2 for remove because unclear

function [ percCorrect, w, predictedVsTrue ] = runSVMLeave1Out( featureVect, classLabels, ...
    numSamplesPerSubj, expLabels )

% leave one subject out cross validation
[ dim, numSamples ] = size( featureVect );

% center each feature (row) and scale it to unit variance; constant features
% are only centered. The statistics use all samples, including the ones each
% fold later holds out for testing.
featureVect = featureVect - repmat( mean(featureVect,2), [1,numSamples] );
featStdev = std( featureVect, 0, 2 );
nonConst = featStdev ~= 0;
featureVect( nonConst,: ) = featureVect( nonConst,: ) ./ repmat( featStdev(nonConst), [1,numSamples] );

if nargin < 4 || isempty(expLabels)
    expLabels = getLeave1OutLabels( numSamples, numSamplesPerSubj );
end

numTrials = length(expLabels);
numCorrect = 0;
totalNumTest = 0;
% LIBSVM options: -t 0 linear kernel, -h 0 shrinking heuristic off,
% -e .01 tolerance of the termination criterion
options = '-t 0 -h 0 -e .01';

predictedVsTrue = cell( numTrials,1 );

for i1 = 1:numTrials

    % training data for this fold, dropping unclear samples (label 2)
    trainLabels = classLabels(:,expLabels(i1).train);
    trainFeatures = featureVect(:,expLabels(i1).train);
    trainFeatures( :, trainLabels==2 ) = [];
    trainLabels( :, trainLabels==2 ) = [];

    % held-out subject's data, dropping unclear samples (label 2)
    testLabels = classLabels(:,expLabels(i1).test);
    testFeatures = featureVect(:,expLabels(i1).test);
    testFeatures( :, testLabels==2 ) = [];
    testLabels( :, testLabels==2 ) = [];

    if ~isempty( testLabels )
        model = svmtrain( trainLabels', trainFeatures', options );
        [predicted, ~, ~] = svmpredict( testLabels', testFeatures', model );
        % count correct predictions exactly instead of rescaling accuracy(1)
        numCorrect = numCorrect + sum( predicted' == testLabels );
        totalNumTest = totalNumTest + length(testLabels);

        predictedVsTrue{i1} = [ predicted'; testLabels ];
    end

end
percCorrect = numCorrect/totalNumTest;   % fraction correct, pooled over folds

% normal vector of optimal hyperplane (ALL DATA)
featureVect( :, classLabels==2 ) = [];
classLabels( :, classLabels==2 ) = [];
model = svmtrain( classLabels', featureVect', options );
% w = sum_i sv_coef(i) * SV_i; positive decision values (w'*x - model.rho)
% correspond to the class model.Label(1), so the sign of w depends on it
w = full(model.SVs') * model.sv_coef;
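The last line of the listing collapses the trained model into a single weight vector, w = sum over support vectors of sv_coef(i) times SV_i. The sketch below checks that identity with a made-up struct that only mimics the field names of a LIBSVM model (SVs, sv_coef, rho), so it runs without LIBSVM; the numbers are illustrative, not output from svmtrain. It also confirms that, for a linear kernel, w'*x - rho reproduces the kernel-expansion decision value.

```matlab
% Sketch with made-up values: only the field names follow LIBSVM's model struct.
model.SVs     = sparse([1 0; 0 2; 1 1]);   % nSV x dim support vectors
model.sv_coef = [0.5; -0.25; -0.25];       % coefficient (alpha_i * y_i) per SV
model.rho     = 0.1;                       % decision value is w'*x - rho
dim = size(model.SVs, 2);

% repmat/sum form: weight each support vector by its coefficient, then sum
wLoop = sum(repmat(model.sv_coef', [dim,1]) .* full(model.SVs'), 2);

% equivalent matrix-vector product
w = full(model.SVs') * model.sv_coef;

% linear kernel: sum_i sv_coef(i) * (SV_i' * x) - rho equals w'*x - rho
x = [2; -1];
decKernel = model.sv_coef' * (full(model.SVs) * x) - model.rho;
decPrimal = w' * x - model.rho;
```

Here w comes out as [0.25; -0.75] and both decision values equal 1.15. With a real model, LIBSVM's FAQ notes that the sign of w and rho is tied to model.Label(1), so that field decides which class a positive projection onto w favors.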