Sleipnir
#include <fstream>
#include <iostream>
#include <iterator>
#include <vector>
#include <queue>
#include <set>

/*****************************************************************************
* This file is provided under the Creative Commons Attribution 3.0 license.
*
* You are free to share, copy, distribute, transmit, or adapt this work
* PROVIDED THAT you attribute the work to the authors listed below.
* For more information, please see the following web page:
* http://creativecommons.org/licenses/by/3.0/
*
* This file is a component of the Sleipnir library for functional genomics,
* authored by:
* Curtis Huttenhower (chuttenh@princeton.edu)
* Mark Schroeder
* Maria D. Chikina
* Olga G. Troyanskaya (ogt@princeton.edu, primary contact)
*
* If you use this library, the included executable tools, or any related
* code in your work, please cite the following publication:
* Curtis Huttenhower, Mark Schroeder, Maria D. Chikina, and
* Olga G. Troyanskaya.
* "The Sleipnir library for computational functional genomics"
*****************************************************************************/
#include "stdafx.h"
#include "cmdline.h"
#include "statistics.h"

using namespace SVMArc;
//#include "../../extlib/svm_light/svm_light/kernel.h"

int main(int iArgs, char** aszArgs) {
    gengetopt_args_info sArgs;

    CPCL PCL;
    CDat DAT;
    SVMArc::CSVMSTRUCTTREE SVM;

    size_t i, j, k, iGene, jGene;
    double bestscore;

    ifstream ifsm;
    if (cmdline_parser(iArgs, aszArgs, &sArgs)) {
        cmdline_parser_print_help();
        return 1;
    }

    // Set SVM parameters from the command line
    SVM.SetLearningAlgorithm(sArgs.learning_algorithm_arg);
    SVM.SetVerbosity(sArgs.verbosity_arg);
    SVM.SetLossFunction(sArgs.loss_function_arg);
    SVM.SetNThreads(sArgs.threads_arg);
    cerr << "SetLossFunction: " << sArgs.loss_function_arg << endl;

    if (sArgs.cross_validation_arg < 1) {
        cerr << "cross_validation is < 1; it must be set to at least 1." << endl;
        return 1;
    }
    else if (sArgs.cross_validation_arg < 2) {
        cerr << "cross_validation is set to 1; no cross-validation holdouts will be run." << endl;
    }

    SVM.SetTradeoff(sArgs.tradeoff_arg);
    SVM.SetEpsilon(sArgs.epsilon_arg);
    if (sArgs.slack_flag)
        SVM.UseSlackRescaling();
    else
        SVM.UseMarginRescaling();

    SVM.ReadOntology(sArgs.ontoparam_arg); // read ontology file

    if (!SVM.parms_check()) {
        cerr << "Parameter check not passed, see above errors" << endl;
        return 1;
    }

    // Read labels from file
    vector<SVMArc::SVMLabel> vecLabels;
    set<string> setLabeledGenes;
    if (sArgs.labels_given) {
        ifsm.clear();
        ifsm.open(sArgs.labels_arg);
        if (ifsm.is_open())
            vecLabels = SVM.ReadLabels(ifsm);
        else {
            cerr << "Could not read label file" << endl;
            exit(1);
        }
        for (i = 0; i < vecLabels.size(); i++)
            setLabeledGenes.insert(vecLabels[i].GeneName);
        cerr << "Read labels from file" << endl;
        SVM.InitializeLikAfterReadLabels();
    }

    // cout << "there are " << vecLabels.size() << " labels processed" << endl;
    size_t iFile;
    vector<string> PCLs;

    if (sArgs.input_given) {
        cerr << "Loading PCL file" << endl;
        if (!PCL.Open(sArgs.input_arg, sArgs.skip_arg, sArgs.mmap_flag)) {
            cerr << "Could not open input PCL" << endl;
            exit(1);
        }
        cerr << "PCL file loaded" << endl;
    }
    //else if (sArgs.dab_input_given) {
    //    cerr << "Loading DAT/DAB file" << endl;
    //    if (!DAT.Open(sArgs.input_arg, !!sArgs.mmap_flag)) {
    //        cerr << "Could not open input DAT/DAB file" << endl;
    //        exit(1);
    //    }
    //}

    // Training: one train/test label split per cross-validation fold
    SAMPLE* pTrainSample;
    vector<vector<SVMArc::SVMLabel> > pTrainVector(sArgs.cross_validation_arg);
    vector<vector<SVMArc::SVMLabel> > pTestVector(sArgs.cross_validation_arg);
    vector<SVMArc::Result> AllResults;
    vector<SVMArc::Result> tmpAllResults;

    if (sArgs.model_given && sArgs.output_given && (!sArgs.labels_given)) {
        if (!sArgs.test_labels_given) { // read a model and classify all genes in the PCL
            vector<SVMLabel> vecAllLabels;

            for (size_t i = 0; i < PCL.GetGenes(); i++)
                vecAllLabels.push_back(SVMLabel(PCL.GetGene(i), 0));

            SVM.ReadModel(sArgs.model_arg);
            cerr << "Model Loaded" << endl;

            AllResults = SVM.Classify(PCL, vecAllLabels);
            ofstream ofsm;
            ofsm.open(sArgs.output_arg);
            if (ofsm.is_open())
                SVM.PrintResults(AllResults, ofsm);
            else {
                cerr << "Could not open output file" << endl;
            }
        }
        else { // read a model and classify only the test examples
            ifsm.clear();
            ifsm.open(sArgs.test_labels_arg);
            if (ifsm.is_open())
                vecLabels = SVM.ReadLabels(ifsm);
            else {
                cerr << "Could not read label file" << endl;
                exit(1);
            }
            for (i = 0; i < vecLabels.size(); i++)
                setLabeledGenes.insert(vecLabels[i].GeneName);
            cerr << "Loading Model" << endl;
            SVM.ReadModel(sArgs.model_arg);
            cerr << "Model Loaded" << endl;

            pTestVector[0].reserve((size_t) vecLabels.size() + 1);
            for (j = 0; j < vecLabels.size(); j++) {
                pTestVector[0].push_back(vecLabels[j]);
            }

            tmpAllResults = SVM.Classify(PCL, pTestVector[0]);
            cerr << "Classified " << tmpAllResults.size() << " examples" << endl;
            AllResults.insert(AllResults.end(), tmpAllResults.begin(), tmpAllResults.end());
            tmpAllResults.resize(0);

            ofstream ofsm;
            ofsm.clear();
            ofsm.open(sArgs.output_arg);
            SVM.PrintResults(AllResults, ofsm);
            return 0;
        }
    } else if (sArgs.output_given && sArgs.labels_given) {
        // do learning and classification with cross-validation
        // set up training data
        if (sArgs.cross_validation_arg > 1) {
            for (i = 0; i < sArgs.cross_validation_arg; i++) {
                pTestVector[i].reserve((size_t) vecLabels.size()
                        / sArgs.cross_validation_arg + sArgs.cross_validation_arg);
                pTrainVector[i].reserve((size_t) vecLabels.size()
                        / (sArgs.cross_validation_arg)
                        * (sArgs.cross_validation_arg - 1)
                        + sArgs.cross_validation_arg);
                for (j = 0; j < vecLabels.size(); j++) {
                    if (j % sArgs.cross_validation_arg == i) {
                        pTestVector[i].push_back(vecLabels[j]);
                    } else {
                        pTrainVector[i].push_back(vecLabels[j]);
                    }
                }
            }
        }
        else { // with fewer than two folds, no cross-validation is done: all labeled genes are used for training and prediction
            // no holdout, so the training set is the same as the test gene set
            pTestVector[0].reserve((size_t) vecLabels.size() + sArgs.cross_validation_arg);
            pTrainVector[0].reserve((size_t) vecLabels.size() + sArgs.cross_validation_arg);

            for (j = 0; j < vecLabels.size(); j++) {
                pTestVector[0].push_back(vecLabels[j]);
                pTrainVector[0].push_back(vecLabels[j]);
            }
        }
        // training data setup done

        // set up validation data: all genes in the PCL that have no label
        vector<SVMLabel> vec_allUnlabeledLabels;
        vector<Result> vec_allUnlabeledResults;
        vector<Result> vec_tmpUnlabeledResults;
        if (sArgs.all_flag) {
            vec_allUnlabeledLabels.reserve(PCL.GetGenes());
            vec_allUnlabeledResults.reserve(PCL.GetGenes());
            for (i = 0; i < PCL.GetGenes(); i++) {
                if (setLabeledGenes.find(PCL.GetGene(i))
                        == setLabeledGenes.end()) {
                    vec_allUnlabeledLabels.push_back(
                            SVMLabel(PCL.GetGene(i), 0));
                    vec_allUnlabeledResults.push_back(Result(PCL.GetGene(i)));
                }
            }
        }

        // train and classify once per cross-validation fold
        for (i = 0; i < sArgs.cross_validation_arg; i++) {
            pTrainSample = SVM.CreateSample(PCL, pTrainVector[i]);

            cerr << "Cross Validation Trial " << i << endl;
            SVM.Learn(*pTrainSample);
            cerr << "Learned" << endl;
            SVMArc::CSVMSTRUCTTREE::FreeSample(*pTrainSample);

            tmpAllResults = SVM.Classify(PCL, pTestVector[i]);
            cerr << "Classified " << tmpAllResults.size() << " examples" << endl;

            AllResults.insert(AllResults.end(), tmpAllResults.begin(), tmpAllResults.end());
            tmpAllResults.resize(0);

            // after the last fold, optionally retrain on all labeled data
            if (i == (sArgs.cross_validation_arg - 1)) {
                if (sArgs.all_flag || sArgs.model_given) {
                    if (sArgs.cross_validation_arg != 1) {
                        pTrainSample = SVM.CreateSample(PCL, vecLabels);
                        cerr << "Train with All Labeled Data" << endl;
                        SVM.Learn(*pTrainSample);
                        cerr << "Learned" << endl;
                        SVMArc::CSVMSTRUCTTREE::FreeSample(*pTrainSample);
                    }
                    if (sArgs.model_given) { // learn once and write the model to file
                        SVM.WriteModel(sArgs.model_arg);
                        cerr << "Model written to file " << sArgs.model_arg << endl;
                    }
                    if (sArgs.all_flag) { // predict the unlabeled genes with the final model
                        vec_allUnlabeledResults = SVM.Classify(PCL, vec_allUnlabeledLabels);
                        cerr << "Classified " << vec_allUnlabeledResults.size() << " examples" << endl;
                    }
                }
            }
        }

        if (sArgs.all_flag) { // add the unlabeled results
            AllResults.insert(AllResults.end(),
                    vec_allUnlabeledResults.begin(),
                    vec_allUnlabeledResults.end());
        }

        ofstream ofsm;
        ofsm.clear();
        ofsm.open(sArgs.output_arg);
        SVM.PrintResults(AllResults, ofsm);
        return 0;

    } else {
        cerr << "More options are needed" << endl;
    }
}
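
The only slightly subtle bookkeeping above is the round-robin fold assignment in the cross-validation branch: label j is held out in fold j % cross_validation_arg and used for training in every other fold, so each labeled gene is predicted exactly once across all folds. The following standalone sketch is not part of Sleipnir; plain gene-name strings stand in for SVMLabel entries and nFolds stands in for sArgs.cross_validation_arg, purely to illustrate the split.

    // Minimal sketch of the round-robin cross-validation split used above.
    #include <cstddef>
    #include <iostream>
    #include <string>
    #include <vector>

    int main() {
        const std::size_t nFolds = 3;                        // stand-in for sArgs.cross_validation_arg
        std::vector<std::string> labels = {"G1", "G2", "G3", // stand-in for vecLabels
                                           "G4", "G5", "G6", "G7"};

        std::vector<std::vector<std::string> > testFolds(nFolds), trainFolds(nFolds);
        for (std::size_t i = 0; i < nFolds; ++i)
            for (std::size_t j = 0; j < labels.size(); ++j)
                // example j is a test example for fold i exactly when j % nFolds == i
                (j % nFolds == i ? testFolds[i] : trainFolds[i]).push_back(labels[j]);

        for (std::size_t i = 0; i < nFolds; ++i)
            std::cout << "fold " << i << ": " << testFolds[i].size() << " test, "
                      << trainFolds[i].size() << " train" << std::endl;
        return 0;
    }

With seven labels and three folds this prints 3/4, 2/5, and 2/5 test/train counts, matching the reserve sizes computed in the tool (roughly size/folds test examples and size/folds*(folds-1) training examples per fold).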