
sortVariablesNB

PURPOSE ^

Naive Bayes ranking of variables (using a cross-validation accuracy criterion)

SYNOPSIS ^

function [ bestVariables, bestToWorst, accuracy ] = sortVariablesNB( featureVect, classLabels, numSamplesPerSubj, topVarsToKeep )

DESCRIPTION ^

 Naive Bayes ranking of variables (using a cross-validation accuracy criterion)
 
 syntax: [ bestVariables, bestToWorst, accuracy ] = sortVariablesNB( featureVect, classLabels, ...
                 numSamplesPerSubj, topVarsToKeep )
 
 Inputs:
   featureVect: all the data samples (dim x numSamples)
   classLabels: all class labels (0 = not learned, 1 = learned, 2 = unsure).
      Samples labeled class 2 are excluded from both training and testing.
   numSamplesPerSubj: featureVect is assumed to be ordered so that each
      subject's samples (counts given by the entries of numSamplesPerSubj)
      are grouped consecutively. This parameter is needed for the
      leave-one-subject-out cross-validation.
   topVarsToKeep: number of top variables to return (default 10)
 
 Outputs:
   bestVariables: indices of the top variables for separating the classes
   bestToWorst: indices ordering all the variables (not just the top)
   accuracy: the associated number of correct CV predictions, sorted best
      to worst
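 
 Example (hypothetical data and variable names; assumes the GVSToolbox
 functions getLeave1OutLabels and findRedundancies are on the path):
 
    numSubj = 6;  sampPerSubj = 8;
    featureVect = randn( 50, numSubj*sampPerSubj);         % dim x numSamples
    classLabels = randi( [0 1], 1, numSubj*sampPerSubj);   % 0/1 labels; 2 = unsure (excluded)
    numSamplesPerSubj = repmat( sampPerSubj, 1, numSubj);
    [ bestVars, order, acc ] = sortVariablesNB( featureVect, classLabels, ...
                    numSamplesPerSubj, 5);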

CROSS-REFERENCE INFORMATION ^

This function calls: getLeave1OutLabels, findRedundancies
This function is called by:

SOURCE CODE ^

% Naive Bayes ranking of variables (using a cross-validation accuracy criterion)
%
% syntax: [ bestVariables, bestToWorst, accuracy ] = sortVariablesNB( featureVect, classLabels, ...
%                 numSamplesPerSubj, topVarsToKeep )
%
% Inputs:
%   featureVect: all the data samples (dim x numSamples)
%   classLabels: all class labels (0 = not learned, 1 = learned, 2 = unsure).
%      Samples labeled class 2 are excluded from both training and testing.
%   numSamplesPerSubj: featureVect is assumed to be ordered so that each
%      subject's samples (counts given by the entries of numSamplesPerSubj)
%      are grouped consecutively. This parameter is needed for the
%      leave-one-subject-out cross-validation.
%   topVarsToKeep: number of top variables to return (default 10)
%
% Outputs:
%   bestVariables: indices of the top variables for separating the classes
%   bestToWorst: indices ordering all the variables (not just the top)
%   accuracy: the associated number of correct CV predictions, sorted best
%      to worst
%

function [ bestVariables, bestToWorst, accuracy ] = sortVariablesNB( featureVect, classLabels, ...
                numSamplesPerSubj, topVarsToKeep )

if nargin < 4 || isempty( topVarsToKeep)
    topVarsToKeep = 10;
end

% leave-one-subject-out cross-validation folds
[ dim, numSamples] = size( featureVect);
expLabels = getLeave1OutLabels( numSamples, numSamplesPerSubj);
numTrials = length(expLabels);
accuracy = zeros( dim, 1);

% center all variables, then scale those with nonzero spread to unit
% variance (skipping zero-variance rows avoids division by zero)
featureVect = featureVect - repmat( mean( featureVect, 2), [1,numSamples]);
featStdev = std( featureVect, 0, 2);
featureVect( featStdev ~= 0,:) = featureVect( featStdev ~= 0,:) ./ ...
                repmat( featStdev( featStdev ~= 0), [1,numSamples]);
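
% Note: getLeave1OutLabels is not shown here; from its use below, expLabels
% is assumed to be a struct array with one element per fold, where
% expLabels(i).train and expLabels(i).test index the columns of featureVect
% (each fold holds out all samples of one subject as the test set).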
for i1 = 1:numTrials

    % training fold: drop the "unsure" (class 2) samples
    trainLabels = classLabels(:,expLabels(i1).train);
    trainFeatures = featureVect(:,expLabels(i1).train);
    trainFeatures( :, trainLabels==2) = [];
    trainLabels( :, trainLabels==2) = [];

    % test fold (the held-out subject): likewise drop class 2
    testLabels = classLabels(:,expLabels(i1).test);
    testFeatures = featureVect(:,expLabels(i1).test);
    testFeatures( :, testLabels==2) = [];
    testLabels( :, testLabels==2) = [];

    % score each feature individually; skip features whose variance is
    % (near) zero overall or within either class, since the Gaussian
    % naive Bayes fit is degenerate for them
    for i2 = 1:dim
        if var( trainFeatures(i2,:)) >= 1e-5 && ...
                var( trainFeatures(i2,trainLabels==0)) >= 1e-5 && ...
                var( trainFeatures(i2,trainLabels==1)) >= 1e-5

            % fit a one-dimensional naive Bayes classifier with uniform
            % class priors and count the correct test predictions
            nb = NaiveBayes.fit( trainFeatures(i2,:)', trainLabels', 'Prior', 'uniform');
            estimate = nb.predict( testFeatures(i2,:)');
            accuracy(i2) = accuracy(i2) + sum( estimate == testLabels');
        end
    end
end
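
% At this point accuracy(i) holds the total number of correct test
% predictions for feature i over all folds (a raw count, not a rate).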

% sort variables in order best to worst
[ accuracy, bestToWorst] = sort( accuracy, 1, 'descend');

% remove redundancies among the ranked features (findRedundancies is
% assumed to return the indices of the rows to keep)
unqIdx = findRedundancies( featureVect( bestToWorst,:) );
bestToWorst = bestToWorst(unqIdx);

bestVariables = bestToWorst( 1:min( topVarsToKeep, length(bestToWorst)) );
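
Note: the NaiveBayes class used above was deprecated in later MATLAB
releases in favor of fitcnb (Statistics and Machine Learning Toolbox,
R2014b and later). A minimal sketch of an equivalent per-feature
fit/predict step, assuming the same variables as in the inner loop:

    nb = fitcnb( trainFeatures(i2,:)', trainLabels', 'Prior', 'uniform');
    estimate = predict( nb, testFeatures(i2,:)');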
