Sleipnir
tools/SVMmulticlass/SVMmulti.cpp
#include <fstream>
#include <iostream>
#include <iterator>
#include <set>
#include <vector>
#include <queue>

/*****************************************************************************
* This file is provided under the Creative Commons Attribution 3.0 license.
*
* You are free to share, copy, distribute, transmit, or adapt this work
* PROVIDED THAT you attribute the work to the authors listed below.
* For more information, please see the following web page:
* http://creativecommons.org/licenses/by/3.0/
*
* This file is a component of the Sleipnir library for functional genomics,
* authored by:
* Curtis Huttenhower (chuttenh@princeton.edu)
* Mark Schroeder
* Maria D. Chikina
* Olga G. Troyanskaya (ogt@princeton.edu, primary contact)
*
* If you use this library, the included executable tools, or any related
* code in your work, please cite the following publication:
* Curtis Huttenhower, Mark Schroeder, Maria D. Chikina, and
* Olga G. Troyanskaya.
* "The Sleipnir library for computational functional genomics"
*****************************************************************************/
#include "stdafx.h"
#include "cmdline.h"
#include "statistics.h"

using namespace SVMArc;
//#include "../../extlib/svm_light/svm_light/kernel.h"

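/*
 * SVMmulti: command-line wrapper around the multiclass structural SVM
 * (SVMArc::CSVMSTRUCTMC). Depending on the options given, main() below either
 * (1) learns a model from labeled genes and writes it to a file,
 * (2) reads an existing model and classifies every gene in the input PCL, or
 * (3) runs cross-validated learning and classification, optionally scoring
 *     unlabeled genes as well.
 */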
int main(int iArgs, char** aszArgs) {
    gengetopt_args_info sArgs;

    CPCL PCL;
    SVMArc::CSVMSTRUCTMC SVM;

    size_t i, j, k, iGene, jGene;
    double bestscore;
    ifstream ifsm;
    if (cmdline_parser(iArgs, aszArgs, &sArgs)) {
        cmdline_parser_print_help();
        return 1;
    }

    //Set Parameters
    SVM.SetLearningAlgorithm(sArgs.learning_algorithm_arg);
    SVM.SetVerbosity(sArgs.verbosity_arg);
    SVM.SetLossFunction(sArgs.loss_function_arg);

    if (sArgs.cross_validation_arg < 1) {
        cerr << "cross_validation is < 1; it must be at least 1" << endl;
        return 1;
    }
    else if (sArgs.cross_validation_arg < 2) {
        cerr << "cross_validation is set to 1; no cross-validation holdouts will be run" << endl;
    }

    SVM.SetTradeoff(sArgs.tradeoff_arg);
    if (sArgs.slack_flag)
        SVM.UseSlackRescaling();
    else
        SVM.UseMarginRescaling();

    if (!SVM.parms_check()) {
        cerr << "Parameter check failed; see errors above" << endl;
        return 1;
    }

    //  cout << "there are " << vecLabels.size() << " labels processed" << endl;
    size_t iFile;
    vector<string> PCLs;
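    //Open the input PCL; its per-gene values provide the feature vectors used to build SVM examples below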
    if (sArgs.input_given) {
        if (!PCL.Open(sArgs.input_arg, sArgs.skip_arg, sArgs.mmap_flag)) {
            cerr << "Could not open input PCL" << endl;
            return 1;
        }
    }

    //Read labels from file
    vector<SVMArc::SVMLabel> vecLabels;
    set<string> setLabeledGenes;
    if (sArgs.labels_given) {
        ifsm.clear();
        ifsm.open(sArgs.labels_arg);
        if (ifsm.is_open())
            vecLabels = SVM.ReadLabels(ifsm);
        else {
            cerr << "Could not open label file" << endl;
            return 1;
        }
        for (i = 0; i < vecLabels.size(); i++)
            setLabeledGenes.insert(vecLabels[i].GeneName);
    }

    //Training
    SAMPLE* pTrainSample;
    //One training partition and one held-out test partition per cross-validation fold
    vector< vector<SVMArc::SVMLabel> > pTrainVector(sArgs.cross_validation_arg);
    vector< vector<SVMArc::SVMLabel> > pTestVector(sArgs.cross_validation_arg);
    vector<SVMArc::Result> AllResults;
    vector<SVMArc::Result> tmpAllResults;

    if (sArgs.model_given && sArgs.labels_given) { //learn once and write to file
        pTrainSample = SVM.CreateSample(PCL, vecLabels);
        SVM.Learn(*pTrainSample);
        SVM.WriteModel(sArgs.model_arg);
    } else if (sArgs.model_given && sArgs.output_given) { //read model and classify all
        vector<SVMLabel> vecAllLabels;

        for (size_t i = 0; i < PCL.GetGenes(); i++)
            vecAllLabels.push_back(SVMLabel(PCL.GetGene(i), 0));

        SVM.ReadModel(sArgs.model_arg);
        AllResults = SVM.Classify(PCL, vecAllLabels);
        ofstream ofsm;
        ofsm.open(sArgs.output_arg);
        if (ofsm.is_open())
            SVM.PrintResults(AllResults, ofsm);
        else {
            cerr << "Could not open output file" << endl;
        }
    } else if (sArgs.output_given && sArgs.labels_given) {
        //Learn and classify with cross-validation.
        //Set up training data: each labeled gene goes into one test fold and into
        //the training sets of all other folds.
        if (sArgs.cross_validation_arg > 1) {
            for (i = 0; i < sArgs.cross_validation_arg; i++) {
                pTestVector[i].reserve((size_t) vecLabels.size()
                    / sArgs.cross_validation_arg + sArgs.cross_validation_arg);
                pTrainVector[i].reserve((size_t) vecLabels.size()
                    / (sArgs.cross_validation_arg)
                    * (sArgs.cross_validation_arg - 1)
                    + sArgs.cross_validation_arg);
                for (j = 0; j < vecLabels.size(); j++) {
                    if (j % sArgs.cross_validation_arg == i) {
                        pTestVector[i].push_back(vecLabels[j]);
                    } else {
                        pTrainVector[i].push_back(vecLabels[j]);
                    }
                }
            }
        }
        else { //fewer than two folds: no cross-validation; every labeled gene is used for both training and prediction

            //no holdout, so the training set is the same as the test gene set
            pTestVector[0].reserve((size_t) vecLabels.size() + sArgs.cross_validation_arg);
            pTrainVector[0].reserve((size_t) vecLabels.size() + sArgs.cross_validation_arg);

            for (j = 0; j < vecLabels.size(); j++) {
                pTestVector[0].push_back(vecLabels[j]);
                pTrainVector[0].push_back(vecLabels[j]);
            }
        }
        //Training data set up

        //Set up validation data: every gene in the PCL that has no label
        vector<SVMLabel> vec_allUnlabeledLabels;
        vector<Result> vec_allUnlabeledResults;
        vector<Result> vec_tmpUnlabeledResults;
        if (sArgs.all_flag) {
            vec_allUnlabeledLabels.reserve(PCL.GetGenes());
            vec_allUnlabeledResults.reserve(PCL.GetGenes());
            for (i = 0; i < PCL.GetGenes(); i++) {
                if (setLabeledGenes.find(PCL.GetGene(i))
                    == setLabeledGenes.end()) {
                    vec_allUnlabeledLabels.push_back(
                        SVMLabel(PCL.GetGene(i), 0));
                    vec_allUnlabeledResults.push_back(Result(PCL.GetGene(i)));
                }
            }
        }
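        //For each fold: build a sample from that fold's training partition, learn a
        //model, classify that fold's held-out genes, and, if the all flag is set,
        //also score every unlabeled gene so the scores can be averaged across folds.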
        for (i = 0; i < sArgs.cross_validation_arg; i++) {
            pTrainSample = SVM.CreateSample(PCL, pTrainVector[i]);

            cerr << "Cross Validation Trial " << i << endl;
            SVM.Learn(*pTrainSample);
            cerr << "Learned" << endl;
            tmpAllResults = SVM.Classify(PCL, pTestVector[i]);
            cerr << "Classified " << tmpAllResults.size() << " examples" << endl;
            AllResults.insert(AllResults.end(), tmpAllResults.begin(), tmpAllResults.end());
            tmpAllResults.resize(0);
            if (sArgs.all_flag) {
                vec_tmpUnlabeledResults = SVM.Classify(PCL, vec_allUnlabeledLabels);

                if (i == 0) {
                    //First fold: initialize the per-class score vectors. Scores are
                    //accessed 1-based by class label below, so slot 0 is a placeholder.
                    for (j = 0; j < vec_tmpUnlabeledResults.size(); j++) {
                        vec_allUnlabeledResults[j].num_class = vec_tmpUnlabeledResults[j].num_class;
                        vec_allUnlabeledResults[j].Scores.push_back(0); //unused slot 0
                        for (k = 1; k <= vec_tmpUnlabeledResults[j].num_class; k++)
                            vec_allUnlabeledResults[j].Scores.push_back(vec_tmpUnlabeledResults[j].Scores[k]);
                    }
                }
                else {
                    //Subsequent folds: accumulate each class score
                    for (j = 0; j < vec_tmpUnlabeledResults.size(); j++)
                        for (k = 1; k <= vec_tmpUnlabeledResults[j].num_class; k++)
                            vec_allUnlabeledResults[j].Scores[k] += vec_tmpUnlabeledResults[j].Scores[k];
                }
            }
            if (i > 0) {
                SVMArc::CSVMSTRUCTMC::FreeSample(*pTrainSample);
            }
        }

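        //Average each unlabeled gene's per-class score over the cross-validation
        //folds and record the class with the smallest averaged score as its
        //predicted label (Value).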
        if (sArgs.all_flag) { //add the unlabeled results
            for (j = 0; j < vec_allUnlabeledResults.size(); j++)
                for (k = 1; k <= vec_allUnlabeledResults[j].num_class; k++) {
                    vec_allUnlabeledResults[j].Scores[k] /= sArgs.cross_validation_arg;
                    if (k == 1 || vec_allUnlabeledResults[j].Scores[k] < bestscore) {
                        bestscore = vec_allUnlabeledResults[j].Scores[k];
                        vec_allUnlabeledResults[j].Value = k;
                    }
                }

            AllResults.insert(AllResults.end(),
                vec_allUnlabeledResults.begin(),
                vec_allUnlabeledResults.end());
        }

        ofstream ofsm;
        ofsm.clear();
        ofsm.open(sArgs.output_arg);
        if (ofsm.is_open())
            SVM.PrintResults(AllResults, ofsm);
        else {
            cerr << "Could not open output file" << endl;
            return 1;
        }
        return 0;

    } else {
        cerr << "More options are needed" << endl;
        return 1;
    }

}