Sleipnir
src/SVMhierarchy.cpp
#include <fstream>
#include <iostream>
#include <iterator>
#include <vector>
#include <queue>
#include <set>

/*****************************************************************************
* This file is provided under the Creative Commons Attribution 3.0 license.
*
* You are free to share, copy, distribute, transmit, or adapt this work
* PROVIDED THAT you attribute the work to the authors listed below.
* For more information, please see the following web page:
* http://creativecommons.org/licenses/by/3.0/
*
* This file is a component of the Sleipnir library for functional genomics,
* authored by:
* Curtis Huttenhower (chuttenh@princeton.edu)
* Mark Schroeder
* Maria D. Chikina
* Olga G. Troyanskaya (ogt@princeton.edu, primary contact)
*
* If you use this library, the included executable tools, or any related
* code in your work, please cite the following publication:
* Curtis Huttenhower, Mark Schroeder, Maria D. Chikina, and
* Olga G. Troyanskaya.
* "The Sleipnir library for computational functional genomics"
*****************************************************************************/
#include "stdafx.h"
#include "cmdline.h"
#include "statistics.h"

using namespace SVMArc;
//#include "../../extlib/svm_light/svm_light/kernel.h"

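/*
 * SVMhierarchy: hierarchical (ontology-structured) SVM-struct classification
 * of genes from a PCL expression matrix, using SVMArc::CSVMSTRUCTTREE.
 * Depending on which arguments are given, the tool either trains a model and
 * writes it to disk, loads an existing model and classifies every gene, or
 * runs cross-validated training and prediction. (Summary inferred from the
 * code below.)
 */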
int main(int iArgs, char** aszArgs) {
    gengetopt_args_info sArgs;

    CPCL PCL;
    SVMArc::CSVMSTRUCTTREE SVM;

    size_t i, j, k;
    double bestscore;

    ifstream ifsm;
    if (cmdline_parser(iArgs, aszArgs, &sArgs)) {
        cmdline_parser_print_help();
        return 1;
    }

    //Set Parameters
    SVM.SetLearningAlgorithm(sArgs.learning_algorithm_arg);
    SVM.SetVerbosity(sArgs.verbosity_arg);
    SVM.SetLossFunction(sArgs.loss_function_arg);

    cerr << "SetLossFunction " << sArgs.loss_function_arg << endl;

    if (sArgs.cross_validation_arg < 1) {
        cerr << "cross_validation is < 1; it must be at least 1" << endl;
        return 1;
    }
    else if (sArgs.cross_validation_arg < 2) {
        cerr << "cross_validation is set to 1; no cross validation holdouts will be run." << endl;
    }

    SVM.SetTradeoff(sArgs.tradeoff_arg);

    if (sArgs.slack_flag)
        SVM.UseSlackRescaling();
    else
        SVM.UseMarginRescaling();

    cerr << "SetRescaling" << endl;
    SVM.ReadOntology(sArgs.ontoparam_arg); // Read ontology file
    if (!SVM.parms_check()) {
        cerr << "Parameter check not passed, see above errors" << endl;
        return 1;
    }

    cerr << "Parameter check passed" << endl;

    //  cout << "there are " << vecLabels.size() << " labels processed" << endl;
    if (sArgs.input_given) {
        if (!PCL.Open(sArgs.input_arg, sArgs.skip_arg, sArgs.mmap_flag)) {
            cerr << "Could not open input PCL" << endl;
            return 1;
        }
    }

    //Read labels from file
    vector<SVMArc::SVMLabel> vecLabels;
    set<string> setLabeledGenes;
    if (sArgs.labels_given) {
        ifsm.clear();
        ifsm.open(sArgs.labels_arg);
        if (ifsm.is_open())
            vecLabels = SVM.ReadLabels(ifsm);
        else {
            cerr << "Could not read label file" << endl;
            return 1;
        }
        for (i = 0; i < vecLabels.size(); i++)
            setLabeledGenes.insert(vecLabels[i].GeneName);
    }
    cerr << "Read labels from file" << endl;

    //Training
    SAMPLE* pTrainSample;
    // Per-fold train/test label sets; use std::vector rather than a
    // variable-length array of class type, which is a compiler extension
    vector<vector<SVMArc::SVMLabel> > pTrainVector(sArgs.cross_validation_arg);
    vector<vector<SVMArc::SVMLabel> > pTestVector(sArgs.cross_validation_arg);
    vector<SVMArc::Result> AllResults;
    vector<SVMArc::Result> tmpAllResults;

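    // Three operating modes, chosen by which arguments were supplied:
    //   model + labels  : train on all labeled genes and write the model to disk
    //   model + output  : load an existing model and classify every gene in the PCL
    //   output + labels : cross-validated training and prediction (handled below)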
    if (sArgs.model_given && sArgs.labels_given) { //learn once and write to file
        pTrainSample = SVM.CreateSample(PCL, vecLabels);
        SVM.Learn(*pTrainSample);
        SVM.WriteModel(sArgs.model_arg);
    } else if (sArgs.model_given && sArgs.output_given) { //read model and classify all
        vector<SVMLabel> vecAllLabels;

        for (size_t i = 0; i < PCL.GetGenes(); i++)
            vecAllLabels.push_back(SVMLabel(PCL.GetGene(i), 0));

        SVM.ReadModel(sArgs.model_arg);
        AllResults = SVM.Classify(PCL, vecAllLabels);
        ofstream ofsm;
        ofsm.open(sArgs.output_arg);
        if (ofsm.is_open())
            SVM.PrintResults(AllResults, ofsm);
        else {
            cerr << "Could not open output file" << endl;
        }
    } else if (sArgs.output_given && sArgs.labels_given) {
        //do learning and classifying with cross validation
        //set up training data
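        // Fold assignment: label j is held out for testing in fold
        // j % cross_validation_arg and used for training in the remaining folds,
        // so each labeled gene is predicted exactly once by a model that did not
        // train on it.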
        if (sArgs.cross_validation_arg > 1) {
            for (i = 0; i < sArgs.cross_validation_arg; i++) {
                pTestVector[i].reserve((size_t) vecLabels.size()
                    / sArgs.cross_validation_arg + sArgs.cross_validation_arg);
                pTrainVector[i].reserve((size_t) vecLabels.size()
                    / (sArgs.cross_validation_arg)
                    * (sArgs.cross_validation_arg - 1)
                    + sArgs.cross_validation_arg);
                for (j = 0; j < vecLabels.size(); j++) {
                    if (j % sArgs.cross_validation_arg == i) {
                        pTestVector[i].push_back(vecLabels[j]);
                    } else {
                        pTrainVector[i].push_back(vecLabels[j]);
                    }
                }
            }
        }
        else { // with fewer than 2 folds, no cross validation is done; all labeled genes are used for both training and prediction
            // no holdout, so the train set is the same as the test gene set
            pTestVector[0].reserve((size_t) vecLabels.size() + sArgs.cross_validation_arg);
            pTrainVector[0].reserve((size_t) vecLabels.size() + sArgs.cross_validation_arg);

            for (j = 0; j < vecLabels.size(); j++) {
                pTestVector[0].push_back(vecLabels[j]);
                pTrainVector[0].push_back(vecLabels[j]);
            }
        }
        //set up training data done

        //set up validation data
        vector<SVMLabel> vec_allUnlabeledLabels;
        vector<Result> vec_allUnlabeledResults;
        vector<Result> vec_tmpUnlabeledResults;
        if (sArgs.all_flag) {
            vec_allUnlabeledLabels.reserve(PCL.GetGenes());
            vec_allUnlabeledResults.reserve(PCL.GetGenes());
            for (i = 0; i < PCL.GetGenes(); i++) {
                if (setLabeledGenes.find(PCL.GetGene(i))
                    == setLabeledGenes.end()) {
                    vec_allUnlabeledLabels.push_back(
                        SVMLabel(PCL.GetGene(i), 0));
                    vec_allUnlabeledResults.push_back(Result(PCL.GetGene(i)));
                }
            }
        }
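        // With the all flag set, every gene in the PCL that has no label is also
        // classified in each cross-validation fold; its per-class scores are summed
        // across folds and averaged after the loop.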
        //run one learning/classification pass per cross-validation fold
        for (i = 0; i < sArgs.cross_validation_arg; i++) {
            pTrainSample = SVM.CreateSample(PCL, pTrainVector[i]);

            cerr << "Cross Validation Trial " << i << endl;
            SVM.Learn(*pTrainSample);
            cerr << "Learned" << endl;
            tmpAllResults = SVM.Classify(PCL, pTestVector[i]);
            cerr << "Classified " << tmpAllResults.size() << " examples" << endl;
            AllResults.insert(AllResults.end(), tmpAllResults.begin(), tmpAllResults.end());
            tmpAllResults.resize(0);
            if (sArgs.all_flag) {
                vec_tmpUnlabeledResults = SVM.Classify(
                    PCL, vec_allUnlabeledLabels);

                if (i == 0) {
                    for (j = 0; j < vec_tmpUnlabeledResults.size(); j++) {
                        vec_allUnlabeledResults[j].num_class = vec_tmpUnlabeledResults[j].num_class;
                        for (k = 1; k <= vec_tmpUnlabeledResults[j].num_class; k++)
                            vec_allUnlabeledResults[j].Scores.push_back(vec_tmpUnlabeledResults[j].Scores[k]);
                    }
                }
                else {
                    for (j = 0; j < vec_tmpUnlabeledResults.size(); j++)
                        for (k = 1; k <= vec_tmpUnlabeledResults[j].num_class; k++)
                            vec_allUnlabeledResults[j].Scores[k] += vec_tmpUnlabeledResults[j].Scores[k];
                }

            }
            if (i > 0) {
                SVMArc::CSVMSTRUCTTREE::FreeSample(*pTrainSample);
            }
        }
        cerr << "Cross validation folds complete" << endl;

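        // Average each unlabeled gene's per-class scores over the folds and take
        // the class with the smallest averaged score as its predicted Value.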
        if (sArgs.all_flag) { //add the unlabeled results
            for (j = 0; j < vec_allUnlabeledResults.size(); j++)
                for (k = 1; k <= vec_allUnlabeledResults[j].num_class; k++) {
                    if (k == 1) {
                        vec_allUnlabeledResults[j].Scores[k] /= sArgs.cross_validation_arg;
                        bestscore = vec_allUnlabeledResults[j].Scores[k];
                        vec_allUnlabeledResults[j].Value = k;
                    } else {
                        vec_allUnlabeledResults[j].Scores[k] /= sArgs.cross_validation_arg;
                        if (vec_allUnlabeledResults[j].Scores[k] < bestscore) {
                            bestscore = vec_allUnlabeledResults[j].Scores[k];
                            vec_allUnlabeledResults[j].Value = k;
                        }
                    }
                }

            AllResults.insert(AllResults.end(),
                vec_allUnlabeledResults.begin(),
                vec_allUnlabeledResults.end());
            cerr << "Added " << vec_allUnlabeledResults.size() << " unlabeled gene predictions" << endl;
        }
        cerr << "Writing results" << endl;

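        // Write every prediction (held-out labeled genes from each fold, plus the
        // averaged unlabeled genes when the all flag is set) to the output file.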
        ofstream ofsm;
        ofsm.clear();
        ofsm.open(sArgs.output_arg);
        if (!ofsm.is_open()) {
            cerr << "Could not open output file" << endl;
            return 1;
        }
        SVM.PrintResults(AllResults, ofsm);
        return 0;

    } else {
        cerr << "More options are needed; run with --help for usage" << endl;
        return 1;
    }

    return 0;
}