Sleipnir
tools/LibSVMer/LibSVMer.cpp
#include <algorithm>
#include <fstream>
#include <map>
#include <queue>
#include <string>
#include <vector>

00006 /*****************************************************************************
00007  * This file is provided under the Creative Commons Attribution 3.0 license.
00008  *
00009  * You are free to share, copy, distribute, transmit, or adapt this work
00010  * PROVIDED THAT you attribute the work to the authors listed below.
00011  * For more information, please see the following web page:
00012  * http://creativecommons.org/licenses/by/3.0/
00013  *
00014  * This file is a component of the Sleipnir library for functional genomics,
00015  * authored by:
00016  * Curtis Huttenhower (chuttenh@princeton.edu)
00017  * Mark Schroeder
00018  * Maria D. Chikina
00019  * Olga G. Troyanskaya (ogt@princeton.edu, primary contact)
00020  *
00021  * If you use this library, the included executable tools, or any related
00022  * code in your work, please cite the following publication:
00023  * Curtis Huttenhower, Mark Schroeder, Maria D. Chikina, and
00024  * Olga G. Troyanskaya.
00025  * "The Sleipnir library for computational functional genomics"
00026  *****************************************************************************/
00027 #include "stdafx.h"
00028 #include "cmdline.h"
00029 #include "statistics.h"
00030 
00031 using namespace LIBSVM;
00032 
00033 vector<LIBSVM::SVMLabel> ReadLabels(ifstream & ifsm) {
00034 
00035   static const size_t c_iBuffer = 1024;
00036   char acBuffer[c_iBuffer];
00037   vector<string> vecstrTokens;
00038   vector<LIBSVM::SVMLabel> vecLabels;
00039   size_t numPositives, numNegatives;
00040   numPositives = numNegatives = 0;
00041   while (!ifsm.eof()) {
00042     ifsm.getline(acBuffer, c_iBuffer - 1);
00043     acBuffer[c_iBuffer - 1] = 0;
00044     vecstrTokens.clear();
00045     CMeta::Tokenize(acBuffer, vecstrTokens);
00046     if (vecstrTokens.empty())
00047       continue;
00048     if (vecstrTokens.size() != 2) {
00049       cerr << "Illegal label line (" << vecstrTokens.size() << "): "
00050         << acBuffer << endl;
00051       continue;
00052     }
00053     vecLabels.push_back(LIBSVM::SVMLabel(vecstrTokens[0], atof(
00054             vecstrTokens[1].c_str())));
00055     if (vecLabels.back().Target > 0)
00056       numPositives++;
00057     else
00058       numNegatives++;
00059   }
00060   return vecLabels;
00061 }
00062 
00063 
00064 struct SortResults {
00065 
00066   bool operator()(const LIBSVM::Result& rOne, const LIBSVM::Result & rTwo) const {
00067     return (rOne.Value > rTwo.Value);
00068   }
00069 };
00070 
00071 
00072 size_t PrintResults(vector<LIBSVM::Result> vecResults, ofstream & ofsm) {
00073   sort(vecResults.begin(), vecResults.end(), SortResults());
00074   int LabelVal;
00075   for (size_t i = 0; i < vecResults.size(); i++) {
00076     ofsm << vecResults[i].GeneName << '\t' << vecResults[i].Target << '\t'
00077       << vecResults[i].Value << endl;
00078   }
00079 };
00080 
/*
 * Plain aggregate of parameter lists. The member names suggest kernel (K),
 * tradeoff and loss settings plus display names. TODO confirm: nothing
 * visible in this file uses this struct.
 */
struct ParamStruct {
  std::vector<float> vecK;
  std::vector<float> vecTradeoff;
  std::vector<size_t> vecLoss;
  std::vector<char*> vecNames;
};
00086 
00087 int main(int iArgs, char** aszArgs) {
00088 
00089   gengetopt_args_info sArgs;
00090 
00091   CPCL PCL;//data
00092   LIBSVM::CLIBSVM SVM;//model
00093 
00094   size_t i, j, iGene, jGene;
00095   ifstream ifsm;
00096 
00097   if (cmdline_parser(iArgs, aszArgs, &sArgs)) {
00098     cmdline_parser_print_help();
00099     return 1;
00100   }
00101 
00102   //Set model parameters
00103 
00104   if (sArgs.cross_validation_arg < 1){
00105     cerr << "cross_valid is <1. Must be set at least 1" << endl;
00106     return 1;
00107   }
00108   else if(sArgs.cross_validation_arg < 2){
00109     cerr << "cross_valid is set to 1. No cross validation holdouts will be run." << endl;
00110     if(sArgs.num_cv_runs_arg > 1){
00111       cerr << "number of cv runs is > 1.  When no cv holdouts, must be set to 1." << endl;
00112       return 1;
00113     }
00114   }
00115 
00116   if (sArgs.num_cv_runs_arg < 1){
00117     cerr << "number of cv runs is < 1. Must be set at least 1" << endl;
00118     return 1;
00119   }
00120 
00121 
00122 
00123   SVM.SetTradeoff(sArgs.tradeoff_arg);
00124   SVM.SetNu(sArgs.nu_arg);
00125   SVM.SetSVMType(sArgs.svm_type_arg);
00126   SVM.SetBalance(sArgs.balance_flag);
00127 
00128   if (!SVM.parms_check()) {
00129     cerr << "Sanity check failed, see above errors" << endl;
00130     return 1;
00131   }
00132 
00133   //TODO: allow multiple PCL files
00134   //size_t iFile; //TODO
00135   // vector<string> PCLs; //TODO
00136 
00137   //check data file
00138   if (sArgs.input_given) {
00139     if (!PCL.Open(sArgs.input_arg, sArgs.skip_arg, sArgs.mmap_flag)) {
00140       cerr << "Could not open input PCL" << endl;
00141       return 1;
00142     }
00143   }
00144 
00145   //read label files
00146   vector<LIBSVM::SVMLabel> vecLabels;
00147   set<string> setLabeledGenes;
00148   if (sArgs.labels_given) {
00149     ifsm.clear();
00150     ifsm.open(sArgs.labels_arg);
00151     if (ifsm.is_open())
00152       vecLabels = ReadLabels(ifsm);
00153     else {
00154       cerr << "Could not read label file" << endl;
00155       return 1;
00156     }
00157     for (i = 0; i < vecLabels.size(); i++)
00158       setLabeledGenes.insert(vecLabels[i].GeneName);
00159   }
00160   
00161   if (sArgs.model_given && sArgs.labels_given) { //learn once and write to file
00162     //TODO
00163     cerr << "not yet implemented: learn once and write to file" << endl;
00164     /*
00165     pTrainSample = CLIBSVM::CreateSample(PCL, vecLabels);
00166     SVM.Learn(*pTrainSample);
00167     SVM.WriteModel(sArgs.model_arg);
00168     */
00169 
00170   } else if (sArgs.model_given && sArgs.output_given) { //read model and classify all
00171     //TODO
00172     cerr << "not yet implemetned: read model and classify all" << endl;
00173     /*
00174     vector<SVMLabel> vecAllLabels;
00175     for (size_t i = 0; i < PCL.GetGenes(); i++)
00176       vecAllLabels.push_back(SVMLabel(PCL.GetGene(i), 0));
00177 
00178     SVM.ReadModel(sArgs.model_arg);
00179     AllResults = SVM.Classify(PCL, vecAllLabels);
00180     ofstream ofsm;
00181     ofsm.open(sArgs.output_arg);
00182     if (ofsm.is_open())
00183       PrintResults(AllResults, ofsm);
00184     else {
00185       cerr << "Could not open output file" << endl;
00186     }
00187     */
00188 
00189   } else if (sArgs.output_given && sArgs.labels_given) {
00190  
00191     LIBSVM::SAMPLE* pTrainSample;//sampled data
00192     size_t numSample;//number of sampling
00193 
00194     numSample = sArgs.cross_validation_arg * sArgs.num_cv_runs_arg;
00195   
00196     vector<LIBSVM::SVMLabel> pTrainVector[numSample];
00197     vector<LIBSVM::SVMLabel> pTestVector[numSample];
00198     vector<LIBSVM::Result> AllResults;
00199     vector<LIBSVM::Result> testResults;
00200 
00201     //set train and test label vectors
00202     //
00203     if( sArgs.cross_validation_arg > 1 && sArgs.num_cv_runs_arg >= 1 ){
00204       //do learning and classifying with cross validation
00205       //
00206       size_t ii, index;
00207 
00208       for (ii = 0; ii < sArgs.num_cv_runs_arg; ii++) {
00209         if(ii > 0)
00210           std::random_shuffle(vecLabels.begin(), vecLabels.end());
00211 
00212         for (i = 0; i < sArgs.cross_validation_arg; i++) {                  
00213           index = sArgs.cross_validation_arg * ii + i;
00214           pTestVector[index].reserve((size_t) vecLabels.size()
00215               / sArgs.cross_validation_arg + sArgs.cross_validation_arg);
00216           pTrainVector[index].reserve((size_t) vecLabels.size()
00217               / (sArgs.cross_validation_arg)
00218               * (sArgs.cross_validation_arg - 1)
00219               + sArgs.cross_validation_arg);
00220           for (j = 0; j < vecLabels.size(); j++) {
00221             if (j % sArgs.cross_validation_arg == i) {
00222               pTestVector[index].push_back(vecLabels[j]);
00223             } else {
00224               pTrainVector[index].push_back(vecLabels[j]);
00225             }
00226           }
00227         }
00228 
00229       }
00230     }  
00231     else{ 
00232       // if you have less than 2 fold cross, no cross validation is done, 
00233       // all train genes are used and predicted
00234       //
00235       cerr << "no holdout so train is the same as test" << endl;
00236       pTestVector[0].reserve((size_t) vecLabels.size() + sArgs.cross_validation_arg);
00237       pTrainVector[0].reserve((size_t) vecLabels.size() + sArgs.cross_validation_arg);
00238 
00239       for (j = 0; j < vecLabels.size(); j++) {
00240         pTestVector[0].push_back(vecLabels[j]);           
00241         pTrainVector[0].push_back(vecLabels[j]);            
00242       }
00243     }
00244 
00245     //if want to make predictions for genes (row) with no label information
00246     //
00247     vector<SVMLabel> vec_allUnlabeledLabels;
00248     vector<Result> vec_allUnlabeledResults;
00249     vector<Result> tmpUnlabeledResults;
00250     if (sArgs.all_flag) {
00251       vec_allUnlabeledLabels.reserve(PCL.GetGenes());
00252       vec_allUnlabeledResults.reserve(PCL.GetGenes());
00253       for (i = 0; i < PCL.GetGenes(); i++) {
00254         if (setLabeledGenes.find(PCL.GetGene(i))
00255             == setLabeledGenes.end()) { // if gene with no label information
00256 
00257           vec_allUnlabeledLabels.push_back(SVMLabel(PCL.GetGene(i), 0));
00258           vec_allUnlabeledResults.push_back(Result(PCL.GetGene(i)));
00259         }
00260       }
00261     }
00262 
00263     bool added;//flag for merging testResults and AllResults
00264 
00265     //for each sample
00266     for (i = 0; i < numSample; i++) {
00267       pTrainSample = LIBSVM::CLIBSVM::CreateSample(PCL, pTrainVector[i]);
00268       cerr << "Trial " << i << endl;
00269 
00270       SVM.Learn(*pTrainSample);
00271       cerr << "Learned" << endl;
00272 
00273       testResults = SVM.Classify(PCL, pTestVector[i]);
00274       cerr << "Classified " << testResults.size() << " test examples" << endl;
00275 
00276       // merge testResults and AllResults
00277       // TODO: make more efficent
00278       for(std::vector<LIBSVM::Result>::iterator it = testResults.begin() ; 
00279           it != testResults.end() ; it ++){
00280 
00281         added = false;
00282         for(std::vector<LIBSVM::Result>::iterator ita = AllResults.begin() ; 
00283             ita != AllResults.end() ; ita ++){
00284 
00285           if ( (*it).GeneName.compare((*ita).GeneName) == 0 ){
00286 
00287             (*ita).Value += (*it).Value;
00288             added = true;
00289             break;
00290           }
00291 
00292         }
00293 
00294         if(!added)
00295           AllResults.push_back((*it));
00296 
00297       }
00298       testResults.clear();
00299 
00300       // classify genes with no label information
00301       if (sArgs.all_flag) {
00302         tmpUnlabeledResults = SVM.Classify(
00303             PCL, vec_allUnlabeledLabels);//make predictions
00304         for (j = 0; j < tmpUnlabeledResults.size(); j++)
00305           vec_allUnlabeledResults[j].Value
00306             += tmpUnlabeledResults[j].Value;
00307       }
00308 
00309       if (i > 0) {
00310         //LIBSVM::CLIBSVM::FreeSample(*pTrainSample);
00311         free(pTrainSample);
00312       }
00313 
00314       //mem = CMeta::GetMemoryUsage();
00315       
00316       cerr << "end of trail" << endl;
00317 
00318     }
00319 
00320     // average results (svm outputs) from multiple cv runs
00321     for(std::vector<LIBSVM::Result>::iterator it = AllResults.begin();
00322         it != AllResults.end(); ++ it){
00323       (*it).Value /= sArgs.num_cv_runs_arg;
00324 
00325     }
00326 
00327     if (sArgs.all_flag) { //add the unlabeled results
00328       for (j = 0; j < vec_allUnlabeledResults.size(); j++)
00329         vec_allUnlabeledResults[j].Value
00330           /= (sArgs.cross_validation_arg * sArgs.num_cv_runs_arg);
00331       AllResults.insert(AllResults.end(),
00332           vec_allUnlabeledResults.begin(),
00333           vec_allUnlabeledResults.end());
00334     }
00335 
00336     ofstream ofsm;
00337     ofsm.clear();
00338     ofsm.open(sArgs.output_arg);
00339     PrintResults(AllResults, ofsm);
00340     return 0;
00341 
00342   } else {
00343     cerr << "More options are needed" << endl;
00344   }
00345 
00346 }
00347