
runLDALeave1Out

PURPOSE

Calculate leave-one-subject-out accuracy for a Bayes classifier with equal class covariances (linear discriminant analysis).

SYNOPSIS

function [percCorrect, w, predictedVsTrue] = runLDALeave1Out( featureVect, classLabels, numSamplesPerSubj, expLabels )

DESCRIPTION

 Calculate leave-one-subject-out accuracy for a Bayes classifier with
 equal class covariances (linear discriminant analysis).
 
 syntax: [percCorrect, w, predictedVsTrue] = runLDALeave1Out( featureVect, classLabels, numSamplesPerSubj, expLabels )
 
 Inputs:
   featureVect: (dim x numSamples) matrix of data
   classLabels: class labels, valued [0,1,2] (for visualization; leave
       empty to skip visualization). Originally, class 1 = learned,
       0 = did not learn, 2 = removed because unclear
   numSamplesPerSubj: number of samples per subject, used to split the
      data for leave-one-subject-out cross-validation
   expLabels: optional input, usually returned by the 'getLeave1OutLabels'
      function; a struct array whose .train and .test fields specify how
      to partition the data for each fold of leave-one-subject-out
      cross-validation
 
 Outputs:
   percCorrect: fraction of test samples correctly classified under
       leave-one-subject-out cross-validation
   w: LDA weight vector fit on all of the well-labeled data (classes 0 and 1)
   predictedVsTrue: cell array containing the [predicted; true] labels of
       the test data for each fold of the cross-validation
 
 Note: class 1 = learned, 0 = did not learn, 2 = removed because unclear
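 
 Example (a minimal sketch: the data below are synthetic, and the
 hand-built expLabels simply mirrors how the struct is indexed in the
 source code; getLeave1OutLabels is the project helper mentioned above):
 
    dim = 10; numSubj = 5; numSamplesPerSubj = 4;
    featureVect = randn( dim, numSubj*numSamplesPerSubj );    % synthetic data
    classLabels = randi( [0,1], 1, numSubj*numSamplesPerSubj );
 
    % let the project helper build the folds ...
    [percCorrect, w, predictedVsTrue] = runLDALeave1Out( featureVect, ...
        classLabels, numSamplesPerSubj );
 
    % ... or build expLabels by hand: one fold per subject, with that
    % subject's samples held out as the test set
    for s = 1:numSubj
        testIdx = (s-1)*numSamplesPerSubj + (1:numSamplesPerSubj);
        expLabels(s).test  = testIdx;
        expLabels(s).train = setdiff( 1:numSubj*numSamplesPerSubj, testIdx );
    end
    [percCorrect, w, predictedVsTrue] = runLDALeave1Out( featureVect, ...
        classLabels, [], expLabels );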

CROSS-REFERENCE INFORMATION

This function calls: (none listed)
This function is called by: (none listed)

SOURCE CODE

% Calculate leave-one-subject-out accuracy for a Bayes classifier with
% equal class covariances (linear discriminant analysis).
%
% syntax: [percCorrect, w, predictedVsTrue] = runLDALeave1Out( featureVect, classLabels, numSamplesPerSubj, expLabels )
%
% Inputs:
%   featureVect: (dim x numSamples) matrix of data
%   classLabels: class labels, valued [0,1,2] (for visualization; leave
%       empty to skip visualization). Originally, class 1 = learned,
%       0 = did not learn, 2 = removed because unclear
%   numSamplesPerSubj: number of samples per subject, used to split the
%      data for leave-one-subject-out cross-validation
%   expLabels: optional input, usually returned by the 'getLeave1OutLabels'
%      function; a struct array whose .train and .test fields specify how
%      to partition the data for each fold
%
% Outputs:
%   percCorrect: fraction of test samples correctly classified under
%       leave-one-subject-out cross-validation
%   w: LDA weight vector fit on all of the well-labeled data (classes 0 and 1)
%   predictedVsTrue: cell array containing the [predicted; true] labels of
%       the test data for each fold of the cross-validation
%
% Note: class 1 = learned, 0 = did not learn, 2 = removed because unclear


function [percCorrect, w, predictedVsTrue] = runLDALeave1Out( featureVect, classLabels, ...
                                    numSamplesPerSubj, expLabels )

numSamples = size( featureVect, 2 );

% center each feature, then scale it to unit variance
featureVect = featureVect - repmat( mean(featureVect,2), [1,numSamples] );
featStdev = std( featureVect, 0, 2 );
featureVect( featStdev ~= 0, :) = featureVect( featStdev ~= 0, :)./repmat( featStdev(featStdev ~= 0), [1,numSamples] );
% featureVect = featureVect./repmat( std( featureVect, 0, 2)+.001, [1,numSamples]);
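% note: features with zero variance are centered but left unscaled, which
% avoids division by zero; the commented-out line above is an alternative
% that instead adds a small constant to every standard deviation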


if nargin < 4 || isempty(expLabels)
    expLabels = getLeave1OutLabels( numSamples, numSamplesPerSubj );
%     expLabels = balanceClasses( expLabels, classLabels );  % even out number of samples in each class
end
numTrials = length(expLabels);

classIndices = cell(2,1);
numCorrect = 0;
totalNumTest = 0;

predictedVsTrue = cell( numTrials, 1 );

% leave-one-subject-out cross-validation
for i1 = 1:numTrials
    % training: gather this fold's training data, dropping samples
    % labeled 2 (unclear)
    trainLabels = classLabels(:,expLabels(i1).train);
    trainFeatures = featureVect(:,expLabels(i1).train);
    trainFeatures( :, trainLabels==2) = [];
    trainLabels( :, trainLabels==2) = [];

    % same for this fold's test data
    testLabels = classLabels(:,expLabels(i1).test);
    testFeatures = featureVect(:,expLabels(i1).test);
    testFeatures( :, testLabels==2) = [];
    testLabels( :, testLabels==2) = [];

    if ~isempty( testLabels)
        classIndices{1} = find( trainLabels == 0); % non-learning
        classIndices{2} = find( trainLabels == 1); % learning

        % class means and pooled within-class scatter matrix Sw
        meanVec = [ mean(trainFeatures(:,classIndices{1}),2), ...
                    mean(trainFeatures(:,classIndices{2}),2) ];
        Xtemp1 = trainFeatures(:,classIndices{1}) - repmat( meanVec(:,1), [1,length(classIndices{1})]);
        Xtemp2 = trainFeatures(:,classIndices{2}) - repmat( meanVec(:,2), [1,length(classIndices{2})]);
        Sw = (Xtemp1*Xtemp1' + Xtemp2*Xtemp2')./length(trainLabels);

        %Sw = cov( trainFeatures');
        w = normc( pinv(Sw)*(meanVec(:,1) - meanVec(:,2)) );
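        % w is Fisher's discriminant direction, proportional to
        % inv(Sw)*(m0 - m1): the direction that maximizes between-class
        % separation relative to within-class scatter; pinv is used so a
        % singular Sw (e.g. when dim exceeds the number of training
        % samples) still yields a solution, and normc rescales w to
        % unit length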

        % testing: project the test data onto w and compare against the
        % projected class means
        projMeans = w'*meanVec;
        projectedData = w'*testFeatures;
        diffMatrix = abs( repmat( projMeans', [1,length(projectedData)]) ...
                        - repmat( projectedData, [length(projMeans),1]) );
        [~, Idx] = min( diffMatrix, [], 1);
        predictedLabels = Idx - 1;  % class labels are 0,1, in that order (of means)
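        % each test sample is assigned the class whose projected mean is
        % nearest; for two classes this is equivalent to thresholding the
        % projection at the midpoint of the two projected means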

        tempNumCorrect = sum( testLabels == predictedLabels );
        numCorrect = numCorrect + tempNumCorrect;
        numTest = length(predictedLabels);
        totalNumTest = totalNumTest + numTest;
        fprintf( 'Accuracy = %.2f, (%d/%d), %d Guessed 1; Train # class 1: %d, class 0: %d. \n', ...
            tempNumCorrect/numTest, tempNumCorrect, numTest, length(find(predictedLabels==1)), ...
            length(find(trainLabels==1)), length(find(trainLabels==0)));

        predictedVsTrue{i1} = [ predictedLabels; testLabels];

    end

end
percCorrect = numCorrect/totalNumTest;  % overall fraction correct across all folds
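% note: accuracy is pooled over all test samples rather than averaged per
% fold, so folds with more well-labeled test samples carry more weight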
%----------------------------

% fit the optimal w on all of the well-labeled data (classes 0 and 1)
featureVect( :, classLabels==2) = [];
classLabels( :, classLabels==2) = [];
classIndices{1} = find( classLabels == 0); % non-learning
classIndices{2} = find( classLabels == 1); % learning

meanVec = [ mean(featureVect(:,classIndices{1}),2), ...
            mean(featureVect(:,classIndices{2}),2) ];
Xtemp1 = featureVect(:,classIndices{1}) - repmat( meanVec(:,1), [1,length(classIndices{1})]);
Xtemp2 = featureVect(:,classIndices{2}) - repmat( meanVec(:,2), [1,length(classIndices{2})]);
Sw = (Xtemp1*Xtemp1' + Xtemp2*Xtemp2')./length(classLabels);

w = pinv(Sw)*(meanVec(:,1) - meanVec(:,2));
w = -w;  % flip the sign so w points from class 0 toward class 1
         % (class-1 samples then project to larger values)
