Sleipnir
#include <fstream>
#include <iostream>
#include <iterator>
#include <vector>
#include <queue>

/*****************************************************************************
 * This file is provided under the Creative Commons Attribution 3.0 license.
 *
 * You are free to share, copy, distribute, transmit, or adapt this work
 * PROVIDED THAT you attribute the work to the authors listed below.
 * For more information, please see the following web page:
 * http://creativecommons.org/licenses/by/3.0/
 *
 * This file is a component of the Sleipnir library for functional genomics,
 * authored by:
 *   Curtis Huttenhower (chuttenh@princeton.edu)
 *   Mark Schroeder
 *   Maria D. Chikina
 *   Olga G. Troyanskaya (ogt@princeton.edu, primary contact)
 *
 * If you use this library, the included executable tools, or any related
 * code in your work, please cite the following publication:
 *   Curtis Huttenhower, Mark Schroeder, Maria D. Chikina, and
 *   Olga G. Troyanskaya.
 *   "The Sleipnir library for computational functional genomics"
 *****************************************************************************/
#include "stdafx.h"
#include "cmdline.h"
#include "statistics.h"

using namespace SVMArc;
//#include "../../extlib/svm_light/svm_light/kernel.h"

int main(int iArgs, char** aszArgs) {
    gengetopt_args_info sArgs;

    CPCL PCL;
    SVMArc::CSVMSTRUCTTREE SVM;

    size_t i, j, k, iGene, jGene;
    double bestscore;

    ifstream ifsm;
    if (cmdline_parser(iArgs, aszArgs, &sArgs)) {
        cmdline_parser_print_help();
        return 1;
    }

    // Set parameters
    SVM.SetLearningAlgorithm(sArgs.learning_algorithm_arg);
    SVM.SetVerbosity(sArgs.verbosity_arg);
    SVM.SetLossFunction(sArgs.loss_function_arg);

    cerr << "SetLossFunction " << sArgs.loss_function_arg << endl;

    if (sArgs.cross_validation_arg < 1) {
        cerr << "cross_valid is < 1. Must be set to at least 1" << endl;
        return 1;
    } else if (sArgs.cross_validation_arg < 2) {
        cerr << "cross_valid is set to 1. No cross-validation holdouts will be run."
             << endl;
    }

    SVM.SetTradeoff(sArgs.tradeoff_arg);

    if (sArgs.slack_flag)
        SVM.UseSlackRescaling();
    else
        SVM.UseMarginRescaling();

    cerr << "SetRescaling" << endl;
    SVM.ReadOntology(sArgs.ontoparam_arg); // read the ontology file
    if (!SVM.parms_check()) {
        cerr << "Parameter check not passed, see above errors" << endl;
        return 1;
    }

    cerr << "Parameter check" << endl;

    // cout << "there are " << vecLabels.size() << " labels processed" << endl;
    size_t iFile;
    vector<string> PCLs;

    // Open the input expression data (PCL)
    if (sArgs.input_given) {
        if (!PCL.Open(sArgs.input_arg, sArgs.skip_arg, sArgs.mmap_flag)) {
            cerr << "Could not open input PCL" << endl;
            return 1;
        }
    }

    // Read labels from file
    vector<SVMArc::SVMLabel> vecLabels;
    set<string> setLabeledGenes;
    if (sArgs.labels_given) {
        ifsm.clear();
        ifsm.open(sArgs.labels_arg);
        if (ifsm.is_open())
            vecLabels = SVM.ReadLabels(ifsm);
        else {
            cerr << "Could not read label file" << endl;
            return 1;
        }
        for (i = 0; i < vecLabels.size(); i++)
            setLabeledGenes.insert(vecLabels[i].GeneName);
    }
    cerr << "Read labels from file" << endl;

    // Training
    // (std::vector is used here instead of variable-length arrays, which are not standard C++)
    SAMPLE* pTrainSample;
    vector< vector<SVMArc::SVMLabel> > pTrainVector(sArgs.cross_validation_arg);
    vector< vector<SVMArc::SVMLabel> > pTestVector(sArgs.cross_validation_arg);
    vector<SVMArc::Result> AllResults;
    vector<SVMArc::Result> tmpAllResults;

    if (sArgs.model_given && sArgs.labels_given) { // learn once and write the model to file
        pTrainSample = SVM.CreateSample(PCL, vecLabels);
        SVM.Learn(*pTrainSample);
        SVM.WriteModel(sArgs.model_arg);
    } else if (sArgs.model_given && sArgs.output_given) { // read a model and classify all genes
        vector<SVMLabel> vecAllLabels;

        for (size_t i = 0; i < PCL.GetGenes(); i++)
            vecAllLabels.push_back(SVMLabel(PCL.GetGene(i), 0));

        SVM.ReadModel(sArgs.model_arg);
        AllResults = SVM.Classify(PCL, vecAllLabels);
        ofstream ofsm;
        ofsm.open(sArgs.output_arg);
        if (ofsm.is_open())
            SVM.PrintResults(AllResults, ofsm);
        else {
            cerr << "Could not open output file" << endl;
        }
    } else if (sArgs.output_given && sArgs.labels_given) {
        // learn and classify with cross-validation
        // set up training data
        if (sArgs.cross_validation_arg > 1) {
            for (i = 0; i < sArgs.cross_validation_arg; i++) {
                pTestVector[i].reserve((size_t) vecLabels.size()
                        / sArgs.cross_validation_arg + sArgs.cross_validation_arg);
                pTrainVector[i].reserve((size_t) vecLabels.size()
                        / (sArgs.cross_validation_arg)
                        * (sArgs.cross_validation_arg - 1)
                        + sArgs.cross_validation_arg);
                for (j = 0; j < vecLabels.size(); j++) {
                    if (j % sArgs.cross_validation_arg == i) {
                        pTestVector[i].push_back(vecLabels[j]);
                    } else {
                        pTrainVector[i].push_back(vecLabels[j]);
                    }
                }
            }
        } else {
            // With fewer than 2 folds no cross-validation is done: all labeled genes
            // are used for training and are also the genes that get predicted,
            // so the training set is the same as the test gene set.
            pTestVector[0].reserve((size_t) vecLabels.size() + sArgs.cross_validation_arg);
            pTrainVector[0].reserve((size_t) vecLabels.size() + sArgs.cross_validation_arg);

            for (j = 0; j < vecLabels.size(); j++) {
                pTestVector[0].push_back(vecLabels[j]);
                pTrainVector[0].push_back(vecLabels[j]);
            }
        }
        // set up training data done

        // set up validation data: every gene in the PCL that has no label
        vector<SVMLabel> vec_allUnlabeledLabels;
        vector<Result> vec_allUnlabeledResults;
        vector<Result> vec_tmpUnlabeledResults;
        if (sArgs.all_flag) {
            vec_allUnlabeledLabels.reserve(PCL.GetGenes());
            vec_allUnlabeledResults.reserve(PCL.GetGenes());
            for (i = 0; i < PCL.GetGenes(); i++) {
                if (setLabeledGenes.find(PCL.GetGene(i))
                        == setLabeledGenes.end()) {
                    vec_allUnlabeledLabels.push_back(
                            SVMLabel(PCL.GetGene(i), 0));
                    vec_allUnlabeledResults.push_back(Result(PCL.GetGene(i)));
                }
            }
        }

        // run one learn/classify pass per cross-validation fold
        for (i = 0; i < sArgs.cross_validation_arg; i++) {
            pTrainSample = SVM.CreateSample(PCL, pTrainVector[i]);

            cerr << "Cross Validation Trial " << i << endl;
            SVM.Learn(*pTrainSample);
            cerr << "Learned" << endl;
            tmpAllResults = SVM.Classify(PCL, pTestVector[i]);
            cerr << "Classified " << tmpAllResults.size() << " examples" << endl;
            AllResults.insert(AllResults.end(), tmpAllResults.begin(), tmpAllResults.end());
            tmpAllResults.resize(0);
            if (sArgs.all_flag) {
                // classify the unlabeled genes with this fold's model and accumulate their scores
                vec_tmpUnlabeledResults = SVM.Classify(PCL, vec_allUnlabeledLabels);

                if (i == 0) {
                    for (j = 0; j < vec_tmpUnlabeledResults.size(); j++) {
                        vec_allUnlabeledResults[j].num_class = vec_tmpUnlabeledResults[j].num_class;
                        for (k = 1; k <= vec_tmpUnlabeledResults[j].num_class; k++)
                            vec_allUnlabeledResults[j].Scores.push_back(vec_tmpUnlabeledResults[j].Scores[k]);
                    }
                } else {
                    for (j = 0; j < vec_tmpUnlabeledResults.size(); j++)
                        for (k = 1; k <= vec_tmpUnlabeledResults[j].num_class; k++)
                            vec_allUnlabeledResults[j].Scores[k] += vec_tmpUnlabeledResults[j].Scores[k];
                }
            }
            if (i > 0) {
                // release this fold's training sample (skipped for the first fold)
                SVMArc::CSVMSTRUCTTREE::FreeSample(*pTrainSample);
            }
        }
        cerr << "5" << endl;

        if (sArgs.all_flag) { // add the unlabeled results
            // average each class score over the folds and predict the class with the lowest averaged score
            for (j = 0; j < vec_allUnlabeledResults.size(); j++)
                for (k = 1; k <= vec_allUnlabeledResults[j].num_class; k++) {
                    if (k == 1) {
                        vec_allUnlabeledResults[j].Scores[k] /= sArgs.cross_validation_arg;
                        bestscore = vec_allUnlabeledResults[j].Scores[k];
                        vec_allUnlabeledResults[j].Value = k;
                    } else {
                        vec_allUnlabeledResults[j].Scores[k] /= sArgs.cross_validation_arg;
                        if (vec_allUnlabeledResults[j].Scores[k] < bestscore) {
                            bestscore = vec_allUnlabeledResults[j].Scores[k];
                            vec_allUnlabeledResults[j].Value = k;
                        }
                    }
                }

            AllResults.insert(AllResults.end(),
                    vec_allUnlabeledResults.begin(),
                    vec_allUnlabeledResults.end());
            cerr << "6" << endl;
        }
        cerr << "7" << endl;

        // write all cross-validation (and, if requested, unlabeled) predictions
        ofstream ofsm;
        ofsm.clear();
        ofsm.open(sArgs.output_arg);
        SVM.PrintResults(AllResults, ofsm);
        return 0;

    } else {
        cerr << "More options are needed" << endl;
    }

}
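
The cross-validation setup above assigns label j to the test set of fold i exactly when j % cross_validation_arg == i, and to the training sets of all other folds, so every labeled gene is held out exactly once. Below is a minimal standalone sketch of that partitioning rule only, using plain strings in place of SVMArc::SVMLabel; the fold count and gene names are illustrative and not taken from the tool.

#include <iostream>
#include <string>
#include <vector>

int main() {
    const size_t iFolds = 3;                      // illustrative fold count
    std::vector<std::string> vecLabels = {        // illustrative gene names
        "g1", "g2", "g3", "g4", "g5", "g6", "g7"};

    std::vector< std::vector<std::string> > vecTrain(iFolds), vecTest(iFolds);
    for (size_t i = 0; i < iFolds; ++i)
        for (size_t j = 0; j < vecLabels.size(); ++j) {
            // Same rule as above: gene j is held out in fold i iff j % iFolds == i.
            if (j % iFolds == i)
                vecTest[i].push_back(vecLabels[j]);
            else
                vecTrain[i].push_back(vecLabels[j]);
        }

    // Print the held-out genes for each fold.
    for (size_t i = 0; i < iFolds; ++i) {
        std::cout << "fold " << i << " test:";
        for (size_t j = 0; j < vecTest[i].size(); ++j)
            std::cout << " " << vecTest[i][j];
        std::cout << std::endl;
    }
    return 0;
}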