Sleipnir
tools/SVMperfer/SVMperfer.cpp
00001 #include <fstream>
00002 
00003 #include <vector>
00004 #include <queue>
00005 
00006 /*****************************************************************************
00007  * This file is provided under the Creative Commons Attribution 3.0 license.
00008  *
00009  * You are free to share, copy, distribute, transmit, or adapt this work
00010  * PROVIDED THAT you attribute the work to the authors listed below.
00011  * For more information, please see the following web page:
00012  * http://creativecommons.org/licenses/by/3.0/
00013  *
00014  * This file is a component of the Sleipnir library for functional genomics,
00015  * authored by:
00016  * Curtis Huttenhower (chuttenh@princeton.edu)
00017  * Mark Schroeder
00018  * Maria D. Chikina
00019  * Olga G. Troyanskaya (ogt@princeton.edu, primary contact)
00020  *
00021  * If you use this library, the included executable tools, or any related
00022  * code in your work, please cite the following publication:
00023  * Curtis Huttenhower, Mark Schroeder, Maria D. Chikina, and
00024  * Olga G. Troyanskaya.
00025  * "The Sleipnir library for computational functional genomics"
00026  *****************************************************************************/
00027 #include "stdafx.h"
00028 #include "cmdline.h"
00029 #include "statistics.h"
00030 
00031 using namespace SVMLight;
00032 //#include "../../extlib/svm_light/svm_light/kernel.h"
00033 
// Reports whether stat() succeeds for `name`, i.e. whether a filesystem
// entry at that path is visible to this process.
inline bool file_exists (const std::string& name) {
    struct stat statInfo;
    const int iStatus = stat(name.c_str(), &statInfo);
    return iStatus == 0;
}
00038 
00039 vector< pair< string, string > > ReadLabelList(ifstream & ifsm, string output_prefix) {
00040   static const size_t c_iBuffer = 1024;
00041   char acBuffer[c_iBuffer];
00042   vector<string> vecstrTokens;
00043   vector< pair < string, string > > inout;
00044   while (!ifsm.eof()) {
00045     ifsm.getline(acBuffer, c_iBuffer - 1);
00046     acBuffer[c_iBuffer - 1] = 0;
00047     vecstrTokens.clear();
00048     CMeta::Tokenize(acBuffer, vecstrTokens);
00049     if (vecstrTokens.empty())
00050       continue;
00051     if (vecstrTokens.size() != 2) {
00052       cerr << "Illegal inout line (" << vecstrTokens.size() << "): "
00053         << acBuffer << endl;
00054       continue;
00055     }
00056     
00057     if( file_exists( output_prefix + "/" + vecstrTokens[1] ) ){
00058       continue;
00059     }
00060     
00061 
00062     //cout << file_exists( vecstrTokens[1] ) << endl;
00063 
00064     inout.push_back( make_pair( vecstrTokens[0], vecstrTokens[1] ) );
00065   }
00066   cout << inout.size() << " number of label files." << endl;
00067   return inout;
00068 
00069 }
00070 
00071 vector<SVMLight::SVMLabel> ReadLabels(ifstream & ifsm) {
00072 
00073   static const size_t c_iBuffer = 1024;
00074   char acBuffer[c_iBuffer];
00075   vector<string> vecstrTokens;
00076   vector<SVMLight::SVMLabel> vecLabels;
00077   size_t numPositives, numNegatives;
00078   numPositives = numNegatives = 0;
00079   while (!ifsm.eof()) {
00080     ifsm.getline(acBuffer, c_iBuffer - 1);
00081     acBuffer[c_iBuffer - 1] = 0;
00082     vecstrTokens.clear();
00083     CMeta::Tokenize(acBuffer, vecstrTokens);
00084     if (vecstrTokens.empty())
00085       continue;
00086     if (vecstrTokens.size() != 2) {
00087       cerr << "Illegal label line (" << vecstrTokens.size() << "): "
00088         << acBuffer << endl;
00089       continue;
00090     }
00091     //cout << vecstrTokens[0] << endl;
00092     //cout << vecstrTokens[1] << endl;
00093 
00094 
00095     vecLabels.push_back(SVMLight::SVMLabel(vecstrTokens[0], atof(
00096             vecstrTokens[1].c_str())));
00097     if (vecLabels.back().Target > 0)
00098       numPositives++;
00099     else
00100       numNegatives++;
00101   }
00102 
00103   cout << numPositives << endl;
00104   cout << numNegatives << endl;
00105 
00106   return vecLabels;
00107 }
00108 
00109 struct SortResults {
00110 
00111   bool operator()(const SVMLight::Result& rOne, const SVMLight::Result & rTwo) const {
00112     return (rOne.Value > rTwo.Value);
00113   }
00114 };
00115 
00116 size_t PrintResults(vector<SVMLight::Result> vecResults, ofstream & ofsm) {
00117   sort(vecResults.begin(), vecResults.end(), SortResults());
00118   int LabelVal;
00119   for (size_t i = 0; i < vecResults.size(); i++) {
00120     ofsm << vecResults[i].GeneName << '\t' << vecResults[i].Target << '\t'
00121       << vecResults[i].Value << endl;
00122   }
00123 }
00124 ;
00125 
// Parallel arrays describing parameter combinations: entry i of every
// vector belongs to the same combination.
struct ParamStruct {
  // k values, one per combination.
  vector<float> vecK;
  // Trade-off values, one per combination.
  vector<float> vecTradeoff;
  // Loss-function identifiers, one per combination.
  vector<size_t> vecLoss;
  // Raw C-string names, one per combination. The struct never frees these.
  // NOTE(review): ReadParamsFromFile fills this with new[]-allocated
  // buffers, so presumably the consumer must delete[] them -- confirm.
  vector<char*> vecNames;
};
00131 
00132 ParamStruct ReadParamsFromFile(ifstream& ifsm, string outFile) {
00133   static const size_t c_iBuffer = 1024;
00134   char acBuffer[c_iBuffer];
00135   char* nameBuffer;
00136   vector<string> vecstrTokens;
00137   size_t extPlace;
00138   string Ext, FileName;
00139   if ((extPlace = outFile.find_first_of(".")) != string::npos) {
00140     FileName = outFile.substr(0, extPlace);
00141     Ext = outFile.substr(extPlace, outFile.size());
00142   } else {
00143     FileName = outFile;
00144     Ext = "";
00145   }
00146   ParamStruct PStruct;
00147   size_t index = 0;
00148   while (!ifsm.eof()) {
00149     ifsm.getline(acBuffer, c_iBuffer - 1);
00150     acBuffer[c_iBuffer - 1] = 0;
00151     vecstrTokens.clear();
00152     CMeta::Tokenize(acBuffer, vecstrTokens);
00153     if (vecstrTokens.empty())
00154       continue;
00155     if (vecstrTokens.size() != 3) {
00156       cerr << "Illegal params line (" << vecstrTokens.size() << "): "
00157         << acBuffer << endl;
00158       continue;
00159     }
00160     if (acBuffer[0] == '#') {
00161       cerr << "skipping " << acBuffer << endl;
00162     } else {
00163       PStruct.vecLoss.push_back(atoi(vecstrTokens[0].c_str()));
00164       PStruct.vecTradeoff.push_back(atof(vecstrTokens[1].c_str()));
00165       PStruct.vecK.push_back(atof(vecstrTokens[2].c_str()));
00166       PStruct.vecNames.push_back(new char[c_iBuffer]);
00167       if (PStruct.vecLoss[index] == 4 || PStruct.vecLoss[index] == 5)
00168         sprintf(PStruct.vecNames[index], "%s_l%d_c%4.6f_k%4.3f%s",
00169             FileName.c_str(), PStruct.vecLoss[index],
00170             PStruct.vecTradeoff[index], PStruct.vecK[index],
00171             Ext.c_str());
00172       else
00173         sprintf(PStruct.vecNames[index], "%s_l%d_c%4.6f%s",
00174             FileName.c_str(), PStruct.vecLoss[index],
00175             PStruct.vecTradeoff[index], Ext.c_str());
00176       index++;
00177     }
00178 
00179   }
00180   return PStruct;
00181 }
00182 
00183 int main(int iArgs, char** aszArgs) {
00184   gengetopt_args_info sArgs;
00185 
00186   CPCL PCL;
00187   SVMLight::CSVMPERF SVM;
00188 
00189   size_t i, j, iGene, jGene;
00190   ifstream ifsm, iifsm;
00191 
00192   if (cmdline_parser(iArgs, aszArgs, &sArgs)) {
00193     cmdline_parser_print_help();
00194     return 1;
00195   }
00196   SVM.SetVerbosity(sArgs.verbosity_arg);
00197   SVM.SetLossFunction(sArgs.error_function_arg);
00198   if (sArgs.k_value_arg > 1) {
00199     cerr << "k_value is >1. Setting default 0.5" << endl;
00200     SVM.SetPrecisionFraction(0.5);
00201   } else if (sArgs.k_value_arg <= 0) {
00202     cerr << "k_value is <=0. Setting default 0.5" << endl;
00203     SVM.SetPrecisionFraction(0.5);
00204   } else {
00205     SVM.SetPrecisionFraction(sArgs.k_value_arg);
00206   }
00207 
00208 
00209   if (sArgs.cross_validation_arg < 1){
00210     cerr << "cross_valid is <1. Must be set at least 1" << endl;
00211     return 1;
00212   }
00213   else if(sArgs.cross_validation_arg < 2){
00214     cerr << "cross_valid is set to 1. No cross validation holdouts will be run." << endl;
00215   }
00216 
00217   SVM.SetTradeoff(sArgs.tradeoff_arg);
00218   if (sArgs.slack_flag)
00219     SVM.UseSlackRescaling();
00220   else
00221     SVM.UseMarginRescaling();
00222 
00223 
00224   if (!SVM.parms_check()) {
00225     cerr << "Sanity check failed, see above errors" << endl;
00226     return 1;
00227   }
00228 
00229   if (!sArgs.output_given){
00230     cerr << "output prefix not provided" << endl;
00231     return 1;
00232   }
00233   
00234   string output_prefix(sArgs.output_arg);
00235 
00236   //  cout << "there are " << vecLabels.size() << " labels processed" << endl;
00237   size_t iFile;
00238   vector<string> PCLs;
00239   if (sArgs.input_given) {
00240     if (!PCL.Open(sArgs.input_arg, sArgs.skip_arg, sArgs.mmap_flag)) {
00241       cerr << "Could not open input PCL" << endl;
00242       return 1;
00243     }
00244   }
00245 
00246 
00247   vector< pair < string, string > > vecLabelLists;
00248   if (sArgs.labels_given) {
00249     ifsm.clear();
00250     ifsm.open(sArgs.labels_arg);
00251     if (ifsm.is_open())
00252       vecLabelLists = ReadLabelList(ifsm, output_prefix);
00253     else {
00254       cerr << "Could not read label list" << endl;
00255       return 1;
00256     }
00257     ifsm.close();
00258   }else{
00259     cerr << "list of labels not given" << endl;
00260     return 1;
00261     //  if (sArgs.labels_given) {
00262     //    vecLabelLists.push_back(pair(sArgs.labels_arg,sArgs.output_arg))
00263     //  }
00264   }
00265   size_t k;
00266   string labels_fn;
00267   string output_fn;
00268 
00269   
00270     SVMLight::SAMPLE* pTrainSample;
00271     vector<SVMLight::Result> AllResults;
00272     vector<SVMLight::Result> tmpAllResults;
00273     vector<SVMLight::SVMLabel> pTrainVector[sArgs.cross_validation_arg];
00274     vector<SVMLight::SVMLabel> pTestVector[sArgs.cross_validation_arg];
00275     vector<SVMLight::SVMLabel> vecLabels;
00276  
00277     string out_fn;
00278 
00279   for(k = 0; k < vecLabelLists.size(); k ++){
00280     labels_fn = vecLabelLists[k].first;
00281     output_fn = vecLabelLists[k].second;
00282 
00283     cout << labels_fn << endl;
00284     cout << output_fn << endl;
00285     
00286     vecLabels.clear();
00287 
00288     ifsm.clear();
00289     ifsm.open(labels_fn.c_str());
00290     if (ifsm.is_open())
00291       vecLabels = ReadLabels(ifsm);
00292     else {
00293       cerr << "Could not read label file" << endl;
00294       return 1;
00295     }
00296     ifsm.close();
00297 
00298     cout << "finished reading labels." << endl;
00299 
00300 
00301     //do learning and classifying with cross validation
00302     if( sArgs.cross_validation_arg > 1){        
00303       for (i = 0; i < sArgs.cross_validation_arg; i++) {
00304 
00305         pTestVector[i].clear();
00306         pTrainVector[i].clear();
00307 
00308         pTestVector[i].reserve((size_t) vecLabels.size()
00309             / sArgs.cross_validation_arg + sArgs.cross_validation_arg);
00310         pTrainVector[i].reserve((size_t) vecLabels.size()
00311             / (sArgs.cross_validation_arg)
00312             * (sArgs.cross_validation_arg - 1)
00313             + sArgs.cross_validation_arg);
00314         for (j = 0; j < vecLabels.size(); j++) {
00315           if (j % sArgs.cross_validation_arg == i) {
00316             pTestVector[i].push_back(vecLabels[j]);
00317           } else {
00318             pTrainVector[i].push_back((vecLabels[j]));
00319           }
00320         }
00321       }
00322     }
00323     else{ // if you have less than 2 fold cross, no cross validation is done, all train genes are used and predicted
00324 
00325       // no holdout so train is the same as test gene set
00326       pTestVector[0].reserve((size_t) vecLabels.size() + sArgs.cross_validation_arg);
00327       pTrainVector[0].reserve((size_t) vecLabels.size() + sArgs.cross_validation_arg);
00328 
00329       for (j = 0; j < vecLabels.size(); j++) {
00330         pTestVector[0].push_back(vecLabels[j]);           
00331         pTrainVector[0].push_back(vecLabels[j]);            
00332       }
00333     }
00334 
00335     for (i = 0; i < sArgs.cross_validation_arg; i++) {
00336       pTrainSample = SVMLight::CSVMPERF::CreateSample(PCL,
00337           pTrainVector[i]);
00338 
00339       cerr << "Cross Validation Trial " << i << endl;
00340 
00341       SVM.Learn(*pTrainSample);
00342       cerr << "Learned" << endl;
00343       tmpAllResults = SVM.Classify(PCL,
00344           pTestVector[i]);
00345       cerr << "Classified " << tmpAllResults.size() << " examples"
00346         << endl;
00347       AllResults.insert(AllResults.end(), tmpAllResults.begin(),
00348           tmpAllResults.end());
00349       tmpAllResults.resize(0);
00350 
00351       if (i > 0) {
00352         SVMLight::CSVMPERF::FreeSample(*pTrainSample);
00353       }
00354     }
00355 
00356     ofstream ofsm;
00357     ofsm.clear();
00358     out_fn = output_prefix + "/" + output_fn;
00359     ofsm.open(out_fn.c_str());
00360     PrintResults(AllResults, ofsm);
00361     cout << "printed: " << output_fn << endl;
00362 
00363  
00364     delete[] pTrainSample;
00365     AllResults.clear();
00366     tmpAllResults.clear();
00367     vecLabels.clear();
00368 
00369 
00370 
00371   } 
00372 }
00373