Sleipnir
tools/SVMhierarchy/SVMhierarchy.cpp
#include <fstream>
#include <iostream>
#include <iterator>
#include <set>
#include <vector>
#include <queue>

/*****************************************************************************
* This file is provided under the Creative Commons Attribution 3.0 license.
*
* You are free to share, copy, distribute, transmit, or adapt this work
* PROVIDED THAT you attribute the work to the authors listed below.
* For more information, please see the following web page:
* http://creativecommons.org/licenses/by/3.0/
*
* This file is a component of the Sleipnir library for functional genomics,
* authored by:
* Curtis Huttenhower (chuttenh@princeton.edu)
* Mark Schroeder
* Maria D. Chikina
* Olga G. Troyanskaya (ogt@princeton.edu, primary contact)
*
* If you use this library, the included executable tools, or any related
* code in your work, please cite the following publication:
* Curtis Huttenhower, Mark Schroeder, Maria D. Chikina, and
* Olga G. Troyanskaya.
* "The Sleipnir library for computational functional genomics"
*****************************************************************************/
#include "stdafx.h"
#include "cmdline.h"
#include "statistics.h"

using namespace SVMArc;
//#include "../../extlib/svm_light/svm_light/kernel.h"

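/*
 * SVMhierarchy learns and applies a hierarchical (structured-output) SVM,
 * SVMArc::CSVMSTRUCTTREE, over genes described by a PCL expression file.
 * Three modes are implemented below, chosen by which options are supplied
 * (option names here are inferred from the gengetopt fields used in this
 * file; the authoritative definitions are in the tool's .ggo file):
 *   1. model + output given, labels not given: load a saved model and
 *      classify every gene in the PCL, or only the test_labels genes.
 *   2. labels + output given: train with cross-validation, optionally
 *      write the model and score unlabeled genes (all flag).
 *   3. otherwise: report that more options are needed.
 */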
int main(int iArgs, char** aszArgs) {
    gengetopt_args_info sArgs;

    CPCL PCL;
    CDat DAT;
    SVMArc::CSVMSTRUCTTREE SVM;

    size_t i, j;
    ifstream ifsm;
    if (cmdline_parser(iArgs, aszArgs, &sArgs)) {
        cmdline_parser_print_help();
        return 1;
    }

    //Set Parameters
    SVM.SetLearningAlgorithm(sArgs.learning_algorithm_arg);
    SVM.SetVerbosity(sArgs.verbosity_arg);
    SVM.SetLossFunction(sArgs.loss_function_arg);
    SVM.SetNThreads(sArgs.threads_arg);
    cerr << "SetLossFunction: " << sArgs.loss_function_arg << endl;

    if (sArgs.cross_validation_arg < 1) {
        cerr << "cross_validation must be at least 1" << endl;
        return 1;
    }
    else if (sArgs.cross_validation_arg < 2) {
        cerr << "cross_validation is set to 1; no cross-validation holdouts will be run." << endl;
    }

    SVM.SetTradeoff(sArgs.tradeoff_arg);
    SVM.SetEpsilon(sArgs.epsilon_arg);
    if (sArgs.slack_flag)
        SVM.UseSlackRescaling();
    else
        SVM.UseMarginRescaling();

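    // The ontology file (ontoparam) is presumed to describe the hierarchy
    // over output labels that the structured-output learner exploits; the
    // labels read below should use identifiers consistent with it.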
    SVM.ReadOntology(sArgs.ontoparam_arg); // Read Ontology File

    if (!SVM.parms_check()) {
        cerr << "Parameter check not passed, see above errors" << endl;
        return 1;
    }

    //Read labels from file
    vector<SVMArc::SVMLabel> vecLabels;
    set<string> setLabeledGenes;
    if (sArgs.labels_given) {
        ifsm.clear();
        ifsm.open(sArgs.labels_arg);
        if (ifsm.is_open())
            vecLabels = SVM.ReadLabels(ifsm);
        else {
            cerr << "Could not read label file" << endl;
            exit(1);
        }
        for (i = 0; i < vecLabels.size(); i++)
            setLabeledGenes.insert(vecLabels[i].GeneName);
        cerr << "Read labels from file" << endl;
        SVM.InitializeLikAfterReadLabels();
    }

    //  cout << "there are " << vecLabels.size() << " labels processed" << endl;

    if (sArgs.input_given) {
        cerr << "Loading PCL file" << endl;
        if (!PCL.Open(sArgs.input_arg, sArgs.skip_arg, sArgs.mmap_flag)) {
            cerr << "Could not open input PCL" << endl;
            exit(1);
        }
        cerr << "PCL file Loaded" << endl;
    }
    //else if (sArgs.dab_input_given){
    //  cerr << "Loading DAT/DAB file" << endl;
    //  if (!DAT.Open(sArgs.input_arg, !!sArgs.mmap_flag)) {
    //      cerr << "Could not open input DAT/DAB file" << endl;
    //      exit(1);
    //  }
    //}

    //Training
    SAMPLE* pTrainSample;
    // one training/test label split per cross-validation fold
    vector<vector<SVMArc::SVMLabel> > pTrainVector(sArgs.cross_validation_arg);
    vector<vector<SVMArc::SVMLabel> > pTestVector(sArgs.cross_validation_arg);
    vector<SVMArc::Result> AllResults;
    vector<SVMArc::Result> tmpAllResults;

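    // Mode dispatch: a saved model plus an output file, but no training
    // labels, means classification only; labels plus an output file means
    // cross-validated training; anything else is reported as insufficient.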
    if (sArgs.model_given && sArgs.output_given && (!sArgs.labels_given)) {
        if (!sArgs.test_labels_given) { //read model and classify all genes
            vector<SVMLabel> vecAllLabels;

            for (size_t i = 0; i < PCL.GetGenes(); i++)
                vecAllLabels.push_back(SVMLabel(PCL.GetGene(i), 0));

            SVM.ReadModel(sArgs.model_arg);
            cerr << "Model Loaded" << endl;

            AllResults = SVM.Classify(PCL, vecAllLabels);
            ofstream ofsm;
            ofsm.open(sArgs.output_arg);
            if (ofsm.is_open())
                SVM.PrintResults(AllResults, ofsm);
            else {
                cerr << "Could not open output file" << endl;
                return 1;
            }
        }
        else { //read model and classify only the test examples
            ifsm.clear();
            ifsm.open(sArgs.test_labels_arg);
            if (ifsm.is_open())
                vecLabels = SVM.ReadLabels(ifsm);
            else {
                cerr << "Could not read label file" << endl;
                exit(1);
            }
            for (i = 0; i < vecLabels.size(); i++)
                setLabeledGenes.insert(vecLabels[i].GeneName);
            cerr << "Loading Model" << endl;
            SVM.ReadModel(sArgs.model_arg);
            cerr << "Model Loaded" << endl;

            pTestVector[0].reserve((size_t) vecLabels.size() + 1);
            for (j = 0; j < vecLabels.size(); j++) {
                pTestVector[0].push_back(vecLabels[j]);
            }

            tmpAllResults = SVM.Classify(PCL, pTestVector[0]);
            cerr << "Classified " << tmpAllResults.size() << " examples" << endl;
            AllResults.insert(AllResults.end(), tmpAllResults.begin(), tmpAllResults.end());
            tmpAllResults.resize(0);
            ofstream ofsm;
            ofsm.clear();
            ofsm.open(sArgs.output_arg);
            if (ofsm.is_open())
                SVM.PrintResults(AllResults, ofsm);
            else {
                cerr << "Could not open output file" << endl;
                return 1;
            }
            return 0;
        }
    } else if (sArgs.output_given && sArgs.labels_given) {
        //do learning and classifying with cross-validation
        //set up training data
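        // Labeled genes are dealt to folds round-robin: gene j is held out
        // in fold (j % cross_validation_arg) and used for training in every
        // other fold, so each labeled gene is predicted exactly once.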
        if (sArgs.cross_validation_arg > 1) {
            for (i = 0; i < sArgs.cross_validation_arg; i++) {
                pTestVector[i].reserve((size_t) vecLabels.size()
                    / sArgs.cross_validation_arg + sArgs.cross_validation_arg);
                pTrainVector[i].reserve((size_t) vecLabels.size()
                    / (sArgs.cross_validation_arg)
                    * (sArgs.cross_validation_arg - 1)
                    + sArgs.cross_validation_arg);
                for (j = 0; j < vecLabels.size(); j++) {
                    if (j % sArgs.cross_validation_arg == i) {
                        pTestVector[i].push_back(vecLabels[j]);
                    } else {
                        pTrainVector[i].push_back(vecLabels[j]);
                    }
                }
            }
        }
        else {
            // fewer than 2 folds: no cross-validation is done; all labeled
            // genes are used for training and the same set is predicted (no holdout)
            pTestVector[0].reserve((size_t) vecLabels.size() + sArgs.cross_validation_arg);
            pTrainVector[0].reserve((size_t) vecLabels.size() + sArgs.cross_validation_arg);

            for (j = 0; j < vecLabels.size(); j++) {
                pTestVector[0].push_back(vecLabels[j]);
                pTrainVector[0].push_back(vecLabels[j]);
            }
        }
        //set up training data done

        //set up validation data
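        // With the all flag, every PCL gene that carries no label is
        // collected here so the model trained on the full labeled set
        // (after the last fold, below) can score it.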
        vector<SVMLabel> vec_allUnlabeledLabels;
        vector<Result> vec_allUnlabeledResults;
        vector<Result> vec_tmpUnlabeledResults;
        if (sArgs.all_flag) {
            vec_allUnlabeledLabels.reserve(PCL.GetGenes());
            vec_allUnlabeledResults.reserve(PCL.GetGenes());
            for (i = 0; i < PCL.GetGenes(); i++) {
                if (setLabeledGenes.find(PCL.GetGene(i))
                    == setLabeledGenes.end()) {
                    vec_allUnlabeledLabels.push_back(
                        SVMLabel(PCL.GetGene(i), 0));
                    vec_allUnlabeledResults.push_back(Result(PCL.GetGene(i)));
                }
            }
        }
        // Cross-validation loop:
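        // Each iteration builds a sparse SAMPLE from the fold's training
        // labels, runs structural SVM learning, frees the sample, classifies
        // the held-out genes, and appends their Results to AllResults.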
        for (i = 0; i < sArgs.cross_validation_arg; i++) {
            pTrainSample = SVM.CreateSample(PCL, pTrainVector[i]);

            cerr << "Cross Validation Trial " << i << endl;
            SVM.Learn(*pTrainSample);
            cerr << "Learned" << endl;
            // release the training sample now that the model has been learned
            SVMArc::CSVMSTRUCTTREE::FreeSample(*pTrainSample);

            tmpAllResults = SVM.Classify(PCL, pTestVector[i]);
            cerr << "Classified " << tmpAllResults.size() << " examples" << endl;

            AllResults.insert(AllResults.end(), tmpAllResults.begin(), tmpAllResults.end());
            tmpAllResults.resize(0);

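            // After the final fold, optionally retrain on all labeled genes
            // so that the written model and the predictions for unlabeled
            // genes reflect the full training set rather than a single fold
            // (skipped when cross_validation_arg == 1, because fold 0 already
            // used every labeled gene).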
            if (i == (sArgs.cross_validation_arg - 1)) {
                if (sArgs.all_flag || sArgs.model_given) {
                    if (sArgs.cross_validation_arg != 1) {
                        pTrainSample = SVM.CreateSample(PCL, vecLabels);
                        cerr << "Train with All Labeled Data" << endl;
                        SVM.Learn(*pTrainSample);
                        cerr << "Learned" << endl;
                        SVMArc::CSVMSTRUCTTREE::FreeSample(*pTrainSample);
                    }
                    if (sArgs.model_given) { //learn once and write the model to file
                        SVM.WriteModel(sArgs.model_arg);
                        cerr << "Model written to file " << sArgs.model_arg << endl;
                    }
                    if (sArgs.all_flag) {
                        vec_allUnlabeledResults = SVM.Classify(PCL, vec_allUnlabeledLabels);
                        cerr << "Classified " << vec_allUnlabeledResults.size() << " examples" << endl;
                    }
                }
            }
        }

        if (sArgs.all_flag) { //append the results for unlabeled genes
            AllResults.insert(AllResults.end(),
                vec_allUnlabeledResults.begin(),
                vec_allUnlabeledResults.end());
        }

        ofstream ofsm;
        ofsm.clear();
        ofsm.open(sArgs.output_arg);
        if (ofsm.is_open())
            SVM.PrintResults(AllResults, ofsm);
        else {
            cerr << "Could not open output file" << endl;
            return 1;
        }
        return 0;

    } else {
        cerr << "More options are needed; see --help" << endl;
        return 1;
    }

    return 0;
}