Sleipnir
|
00001 #include <fstream> 00002 #include <iostream> 00003 #include <iterator> 00004 #include <vector> 00005 #include <queue> 00006 00007 /***************************************************************************** 00008 * This file is provided under the Creative Commons Attribution 3.0 license. 00009 * 00010 * You are free to share, copy, distribute, transmit, or adapt this work 00011 * PROVIDED THAT you attribute the work to the authors listed below. 00012 * For more information, please see the following web page: 00013 * http://creativecommons.org/licenses/by/3.0/ 00014 * 00015 * This file is a component of the Sleipnir library for functional genomics, 00016 * authored by: 00017 * Curtis Huttenhower (chuttenh@princeton.edu) 00018 * Mark Schroeder 00019 * Maria D. Chikina 00020 * Olga G. Troyanskaya (ogt@princeton.edu, primary contact) 00021 * 00022 * If you use this library, the included executable tools, or any related 00023 * code in your work, please cite the following publication: 00024 * Curtis Huttenhower, Mark Schroeder, Maria D. Chikina, and 00025 * Olga G. Troyanskaya. 00026 * "The Sleipnir library for computational functional genomics" 00027 *****************************************************************************/ 00028 #include "stdafx.h" 00029 #include "cmdline.h" 00030 #include "statistics.h" 00031 00032 using namespace SVMArc; 00033 //#include "../../extlib/svm_light/svm_light/kernel.h" 00034 00035 00036 00037 00038 00039 int main(int iArgs, char** aszArgs) { 00040 gengetopt_args_info sArgs; 00041 00042 CPCL PCL; 00043 SVMArc::CSVMSTRUCTMC SVM; 00044 00045 size_t i, j, k , iGene, jGene; 00046 double bestscore; 00047 ; 00048 ifstream ifsm; 00049 if (cmdline_parser(iArgs, aszArgs, &sArgs)) { 00050 cmdline_parser_print_help(); 00051 return 1; 00052 } 00053 00054 //Set Parameters 00055 SVM.SetLearningAlgorithm(sArgs.learning_algorithm_arg); 00056 SVM.SetVerbosity(sArgs.verbosity_arg); 00057 SVM.SetLossFunction(sArgs.loss_function_arg); 00058 00059 00060 if (sArgs.cross_validation_arg < 1){ 00061 cerr << "cross_valid is <1. Must be set at least 1" << endl; 00062 return 1; 00063 } 00064 else if(sArgs.cross_validation_arg < 2){ 00065 cerr << "cross_valid is set to 1. No cross validation holdouts will be run." << endl; 00066 } 00067 00068 SVM.SetTradeoff(sArgs.tradeoff_arg); 00069 if (sArgs.slack_flag) 00070 SVM.UseSlackRescaling(); 00071 else 00072 SVM.UseMarginRescaling(); 00073 00074 00075 if (!SVM.parms_check()) { 00076 cerr << "Parameter check not passed, see above errors" << endl; 00077 return 1; 00078 } 00079 00080 // cout << "there are " << vecLabels.size() << " labels processed" << endl; 00081 size_t iFile; 00082 vector<string> PCLs; 00083 if (sArgs.input_given) { 00084 if (!PCL.Open(sArgs.input_arg, sArgs.skip_arg, sArgs.mmap_flag)) { 00085 cerr << "Could not open input PCL" << endl; 00086 return 1; 00087 } 00088 } 00089 00090 //Read labels from file 00091 vector<SVMArc::SVMLabel> vecLabels; 00092 set<string> setLabeledGenes; 00093 if (sArgs.labels_given) { 00094 ifsm.clear(); 00095 ifsm.open(sArgs.labels_arg); 00096 if (ifsm.is_open()) 00097 vecLabels = SVM.ReadLabels(ifsm); 00098 else { 00099 cerr << "Could not read label file" << endl; 00100 return 1; 00101 } 00102 for (i = 0; i < vecLabels.size(); i++) 00103 setLabeledGenes.insert(vecLabels[i].GeneName); 00104 } 00105 00106 00107 //Training 00108 SAMPLE* pTrainSample; 00109 vector<SVMArc::SVMLabel> pTrainVector[sArgs.cross_validation_arg]; 00110 vector<SVMArc::SVMLabel> pTestVector[sArgs.cross_validation_arg]; 00111 vector<SVMArc::Result> AllResults; 00112 vector<SVMArc::Result> tmpAllResults; 00113 00114 if (sArgs.model_given && sArgs.labels_given) { //learn once and write to file 00115 pTrainSample = SVM.CreateSample(PCL, vecLabels); 00116 SVM.Learn(*pTrainSample); 00117 SVM.WriteModel(sArgs.model_arg); 00118 } else if (sArgs.model_given && sArgs.output_given) { //read model and classify all 00119 vector<SVMLabel> vecAllLabels; 00120 00121 for (size_t i = 0; i < PCL.GetGenes(); i++) 00122 vecAllLabels.push_back(SVMLabel(PCL.GetGene(i), 0)); 00123 00124 SVM.ReadModel(sArgs.model_arg); 00125 AllResults = SVM.Classify(PCL, vecAllLabels); 00126 ofstream ofsm; 00127 ofsm.open(sArgs.output_arg); 00128 if (ofsm.is_open()) 00129 SVM.PrintResults(AllResults, ofsm); 00130 else { 00131 cerr << "Could not open output file" << endl; 00132 } 00133 } else if (sArgs.output_given && sArgs.labels_given) { 00134 //do learning and classifying with cross validation 00135 //set up training data 00136 if( sArgs.cross_validation_arg > 1){ 00137 for (i = 0; i < sArgs.cross_validation_arg; i++) { 00138 pTestVector[i].reserve((size_t) vecLabels.size() 00139 / sArgs.cross_validation_arg + sArgs.cross_validation_arg); 00140 pTrainVector[i].reserve((size_t) vecLabels.size() 00141 / (sArgs.cross_validation_arg) 00142 * (sArgs.cross_validation_arg - 1) 00143 + sArgs.cross_validation_arg); 00144 for (j = 0; j < vecLabels.size(); j++) { 00145 if (j % sArgs.cross_validation_arg == i) { 00146 pTestVector[i].push_back(vecLabels[j]); 00147 } else { 00148 pTrainVector[i].push_back((vecLabels[j])); 00149 } 00150 } 00151 } 00152 } 00153 else{ // if you have less than 2 fold cross, no cross validation is done, all train genes are used and predicted 00154 00155 // no holdout so train is the same as test gene set 00156 pTestVector[0].reserve((size_t) vecLabels.size() + sArgs.cross_validation_arg); 00157 pTrainVector[0].reserve((size_t) vecLabels.size() + sArgs.cross_validation_arg); 00158 00159 for (j = 0; j < vecLabels.size(); j++) { 00160 pTestVector[0].push_back(vecLabels[j]); 00161 pTrainVector[0].push_back(vecLabels[j]); 00162 } 00163 } 00164 //set up training data done 00165 00166 //set up validation data 00167 vector<SVMLabel> vec_allUnlabeledLabels; 00168 vector<Result> vec_allUnlabeledResults; 00169 vector<Result> vec_tmpUnlabeledResults; 00170 if (sArgs.all_flag) { 00171 vec_allUnlabeledLabels.reserve(PCL.GetGenes()); 00172 vec_allUnlabeledResults.reserve(PCL.GetGenes()); 00173 for (i = 0; i < PCL.GetGenes(); i++) { 00174 if (setLabeledGenes.find(PCL.GetGene(i)) 00175 == setLabeledGenes.end()) { 00176 vec_allUnlabeledLabels.push_back( 00177 SVMLabel(PCL.GetGene(i), 0)); 00178 vec_allUnlabeledResults.push_back(Result(PCL.GetGene(i))); 00179 } 00180 } 00181 } 00182 //run once 00183 for (i = 0; i < sArgs.cross_validation_arg; i++) { 00184 pTrainSample = SVM.CreateSample(PCL, 00185 pTrainVector[i]); 00186 00187 cerr << "Cross Validation Trial " << i << endl; 00188 SVM.Learn(*pTrainSample); 00189 cerr << "Learned" << endl; 00190 tmpAllResults = SVM.Classify(PCL, pTestVector[i]); 00191 cerr << "Classified " << tmpAllResults.size() << " examples"<< endl; 00192 AllResults.insert(AllResults.end(), tmpAllResults.begin(), tmpAllResults.end()); 00193 tmpAllResults.resize(0); 00194 if (sArgs.all_flag) { 00195 vec_tmpUnlabeledResults = SVM.Classify( 00196 PCL, vec_allUnlabeledLabels); 00197 00198 if(i == 0){ 00199 for (j = 0; j < vec_tmpUnlabeledResults.size(); j++){ 00200 vec_allUnlabeledResults[j].num_class = vec_tmpUnlabeledResults[j].num_class; 00201 for( k = 1; k <= vec_tmpUnlabeledResults[j].num_class; k++) 00202 vec_allUnlabeledResults[j].Scores.push_back(vec_tmpUnlabeledResults[j].Scores[k]); 00203 } 00204 } 00205 else{ 00206 for (j = 0; j < vec_tmpUnlabeledResults.size(); j++) 00207 for( k = 1; k <= vec_tmpUnlabeledResults[j].num_class; k++) 00208 vec_allUnlabeledResults[j].Scores[k] += vec_tmpUnlabeledResults[j].Scores[k]; 00209 } 00210 00211 } 00212 if (i > 0) { 00213 SVMArc::CSVMSTRUCTMC::FreeSample(*pTrainSample); 00214 } 00215 } 00216 00217 if (sArgs.all_flag) { //add the unlabeled results 00218 for (j = 0; j < vec_allUnlabeledResults.size(); j++) 00219 for( k = 1; k <= vec_allUnlabeledResults[j].num_class; k++){ 00220 if(k==1){ 00221 vec_allUnlabeledResults[j].Scores[k]/= sArgs.cross_validation_arg; 00222 bestscore=vec_allUnlabeledResults[j].Scores[k]; 00223 vec_allUnlabeledResults[j].Value=k; 00224 }else{ 00225 vec_allUnlabeledResults[j].Scores[k]/= sArgs.cross_validation_arg; 00226 if(vec_allUnlabeledResults[j].Scores[k] < bestscore){ 00227 bestscore = vec_allUnlabeledResults[j].Scores[k]; 00228 vec_allUnlabeledResults[j].Value=k; 00229 } 00230 } 00231 } 00232 00233 AllResults.insert(AllResults.end(), 00234 vec_allUnlabeledResults.begin(), 00235 vec_allUnlabeledResults.end()); 00236 } 00237 00238 ofstream ofsm; 00239 ofsm.clear(); 00240 ofsm.open(sArgs.output_arg); 00241 SVM.PrintResults(AllResults, ofsm); 00242 return 0; 00243 00244 } else { 00245 cerr << "More options are needed" << endl; 00246 } 00247 00248 } 00249