Sleipnir
|
00001 #include <fstream> 00002 00003 #include <vector> 00004 #include <queue> 00005 00006 /***************************************************************************** 00007 * This file is provided under the Creative Commons Attribution 3.0 license. 00008 * 00009 * You are free to share, copy, distribute, transmit, or adapt this work 00010 * PROVIDED THAT you attribute the work to the authors listed below. 00011 * For more information, please see the following web page: 00012 * http://creativecommons.org/licenses/by/3.0/ 00013 * 00014 * This file is a component of the Sleipnir library for functional genomics, 00015 * authored by: 00016 * Curtis Huttenhower (chuttenh@princeton.edu) 00017 * Mark Schroeder 00018 * Maria D. Chikina 00019 * Olga G. Troyanskaya (ogt@princeton.edu, primary contact) 00020 * 00021 * If you use this library, the included executable tools, or any related 00022 * code in your work, please cite the following publication: 00023 * Curtis Huttenhower, Mark Schroeder, Maria D. Chikina, and 00024 * Olga G. Troyanskaya. 00025 * "The Sleipnir library for computational functional genomics" 00026 *****************************************************************************/ 00027 #include "stdafx.h" 00028 #include "cmdline.h" 00029 #include "statistics.h" 00030 00031 using namespace LIBSVM; 00032 00033 vector<LIBSVM::SVMLabel> ReadLabels(ifstream & ifsm) { 00034 00035 static const size_t c_iBuffer = 1024; 00036 char acBuffer[c_iBuffer]; 00037 vector<string> vecstrTokens; 00038 vector<LIBSVM::SVMLabel> vecLabels; 00039 size_t numPositives, numNegatives; 00040 numPositives = numNegatives = 0; 00041 while (!ifsm.eof()) { 00042 ifsm.getline(acBuffer, c_iBuffer - 1); 00043 acBuffer[c_iBuffer - 1] = 0; 00044 vecstrTokens.clear(); 00045 CMeta::Tokenize(acBuffer, vecstrTokens); 00046 if (vecstrTokens.empty()) 00047 continue; 00048 if (vecstrTokens.size() != 2) { 00049 cerr << "Illegal label line (" << vecstrTokens.size() << "): " 00050 << acBuffer << endl; 00051 continue; 00052 } 00053 vecLabels.push_back(LIBSVM::SVMLabel(vecstrTokens[0], atof( 00054 vecstrTokens[1].c_str()))); 00055 if (vecLabels.back().Target > 0) 00056 numPositives++; 00057 else 00058 numNegatives++; 00059 } 00060 return vecLabels; 00061 } 00062 00063 00064 struct SortResults { 00065 00066 bool operator()(const LIBSVM::Result& rOne, const LIBSVM::Result & rTwo) const { 00067 return (rOne.Value > rTwo.Value); 00068 } 00069 }; 00070 00071 00072 size_t PrintResults(vector<LIBSVM::Result> vecResults, ofstream & ofsm) { 00073 sort(vecResults.begin(), vecResults.end(), SortResults()); 00074 int LabelVal; 00075 for (size_t i = 0; i < vecResults.size(); i++) { 00076 ofsm << vecResults[i].GeneName << '\t' << vecResults[i].Target << '\t' 00077 << vecResults[i].Value << endl; 00078 } 00079 }; 00080 00081 struct ParamStruct { 00082 vector<float> vecK, vecTradeoff; 00083 vector<size_t> vecLoss; 00084 vector<char*> vecNames; 00085 }; 00086 00087 int main(int iArgs, char** aszArgs) { 00088 00089 gengetopt_args_info sArgs; 00090 00091 CPCL PCL;//data 00092 LIBSVM::CLIBSVM SVM;//model 00093 00094 size_t i, j, iGene, jGene; 00095 ifstream ifsm; 00096 00097 if (cmdline_parser(iArgs, aszArgs, &sArgs)) { 00098 cmdline_parser_print_help(); 00099 return 1; 00100 } 00101 00102 //Set model parameters 00103 00104 if (sArgs.cross_validation_arg < 1){ 00105 cerr << "cross_valid is <1. Must be set at least 1" << endl; 00106 return 1; 00107 } 00108 else if(sArgs.cross_validation_arg < 2){ 00109 cerr << "cross_valid is set to 1. No cross validation holdouts will be run." << endl; 00110 if(sArgs.num_cv_runs_arg > 1){ 00111 cerr << "number of cv runs is > 1. When no cv holdouts, must be set to 1." << endl; 00112 return 1; 00113 } 00114 } 00115 00116 if (sArgs.num_cv_runs_arg < 1){ 00117 cerr << "number of cv runs is < 1. Must be set at least 1" << endl; 00118 return 1; 00119 } 00120 00121 00122 00123 SVM.SetTradeoff(sArgs.tradeoff_arg); 00124 SVM.SetNu(sArgs.nu_arg); 00125 SVM.SetSVMType(sArgs.svm_type_arg); 00126 SVM.SetBalance(sArgs.balance_flag); 00127 00128 if (!SVM.parms_check()) { 00129 cerr << "Sanity check failed, see above errors" << endl; 00130 return 1; 00131 } 00132 00133 //TODO: allow multiple PCL files 00134 //size_t iFile; //TODO 00135 // vector<string> PCLs; //TODO 00136 00137 //check data file 00138 if (sArgs.input_given) { 00139 if (!PCL.Open(sArgs.input_arg, sArgs.skip_arg, sArgs.mmap_flag)) { 00140 cerr << "Could not open input PCL" << endl; 00141 return 1; 00142 } 00143 } 00144 00145 //read label files 00146 vector<LIBSVM::SVMLabel> vecLabels; 00147 set<string> setLabeledGenes; 00148 if (sArgs.labels_given) { 00149 ifsm.clear(); 00150 ifsm.open(sArgs.labels_arg); 00151 if (ifsm.is_open()) 00152 vecLabels = ReadLabels(ifsm); 00153 else { 00154 cerr << "Could not read label file" << endl; 00155 return 1; 00156 } 00157 for (i = 0; i < vecLabels.size(); i++) 00158 setLabeledGenes.insert(vecLabels[i].GeneName); 00159 } 00160 00161 if (sArgs.model_given && sArgs.labels_given) { //learn once and write to file 00162 //TODO 00163 cerr << "not yet implemented: learn once and write to file" << endl; 00164 /* 00165 pTrainSample = CLIBSVM::CreateSample(PCL, vecLabels); 00166 SVM.Learn(*pTrainSample); 00167 SVM.WriteModel(sArgs.model_arg); 00168 */ 00169 00170 } else if (sArgs.model_given && sArgs.output_given) { //read model and classify all 00171 //TODO 00172 cerr << "not yet implemetned: read model and classify all" << endl; 00173 /* 00174 vector<SVMLabel> vecAllLabels; 00175 for (size_t i = 0; i < PCL.GetGenes(); i++) 00176 vecAllLabels.push_back(SVMLabel(PCL.GetGene(i), 0)); 00177 00178 SVM.ReadModel(sArgs.model_arg); 00179 AllResults = SVM.Classify(PCL, vecAllLabels); 00180 ofstream ofsm; 00181 ofsm.open(sArgs.output_arg); 00182 if (ofsm.is_open()) 00183 PrintResults(AllResults, ofsm); 00184 else { 00185 cerr << "Could not open output file" << endl; 00186 } 00187 */ 00188 00189 } else if (sArgs.output_given && sArgs.labels_given) { 00190 00191 LIBSVM::SAMPLE* pTrainSample;//sampled data 00192 size_t numSample;//number of sampling 00193 00194 numSample = sArgs.cross_validation_arg * sArgs.num_cv_runs_arg; 00195 00196 vector<LIBSVM::SVMLabel> pTrainVector[numSample]; 00197 vector<LIBSVM::SVMLabel> pTestVector[numSample]; 00198 vector<LIBSVM::Result> AllResults; 00199 vector<LIBSVM::Result> testResults; 00200 00201 //set train and test label vectors 00202 // 00203 if( sArgs.cross_validation_arg > 1 && sArgs.num_cv_runs_arg >= 1 ){ 00204 //do learning and classifying with cross validation 00205 // 00206 size_t ii, index; 00207 00208 for (ii = 0; ii < sArgs.num_cv_runs_arg; ii++) { 00209 if(ii > 0) 00210 std::random_shuffle(vecLabels.begin(), vecLabels.end()); 00211 00212 for (i = 0; i < sArgs.cross_validation_arg; i++) { 00213 index = sArgs.cross_validation_arg * ii + i; 00214 pTestVector[index].reserve((size_t) vecLabels.size() 00215 / sArgs.cross_validation_arg + sArgs.cross_validation_arg); 00216 pTrainVector[index].reserve((size_t) vecLabels.size() 00217 / (sArgs.cross_validation_arg) 00218 * (sArgs.cross_validation_arg - 1) 00219 + sArgs.cross_validation_arg); 00220 for (j = 0; j < vecLabels.size(); j++) { 00221 if (j % sArgs.cross_validation_arg == i) { 00222 pTestVector[index].push_back(vecLabels[j]); 00223 } else { 00224 pTrainVector[index].push_back(vecLabels[j]); 00225 } 00226 } 00227 } 00228 00229 } 00230 } 00231 else{ 00232 // if you have less than 2 fold cross, no cross validation is done, 00233 // all train genes are used and predicted 00234 // 00235 cerr << "no holdout so train is the same as test" << endl; 00236 pTestVector[0].reserve((size_t) vecLabels.size() + sArgs.cross_validation_arg); 00237 pTrainVector[0].reserve((size_t) vecLabels.size() + sArgs.cross_validation_arg); 00238 00239 for (j = 0; j < vecLabels.size(); j++) { 00240 pTestVector[0].push_back(vecLabels[j]); 00241 pTrainVector[0].push_back(vecLabels[j]); 00242 } 00243 } 00244 00245 //if want to make predictions for genes (row) with no label information 00246 // 00247 vector<SVMLabel> vec_allUnlabeledLabels; 00248 vector<Result> vec_allUnlabeledResults; 00249 vector<Result> tmpUnlabeledResults; 00250 if (sArgs.all_flag) { 00251 vec_allUnlabeledLabels.reserve(PCL.GetGenes()); 00252 vec_allUnlabeledResults.reserve(PCL.GetGenes()); 00253 for (i = 0; i < PCL.GetGenes(); i++) { 00254 if (setLabeledGenes.find(PCL.GetGene(i)) 00255 == setLabeledGenes.end()) { // if gene with no label information 00256 00257 vec_allUnlabeledLabels.push_back(SVMLabel(PCL.GetGene(i), 0)); 00258 vec_allUnlabeledResults.push_back(Result(PCL.GetGene(i))); 00259 } 00260 } 00261 } 00262 00263 bool added;//flag for merging testResults and AllResults 00264 00265 //for each sample 00266 for (i = 0; i < numSample; i++) { 00267 pTrainSample = LIBSVM::CLIBSVM::CreateSample(PCL, pTrainVector[i]); 00268 cerr << "Trial " << i << endl; 00269 00270 SVM.Learn(*pTrainSample); 00271 cerr << "Learned" << endl; 00272 00273 testResults = SVM.Classify(PCL, pTestVector[i]); 00274 cerr << "Classified " << testResults.size() << " test examples" << endl; 00275 00276 // merge testResults and AllResults 00277 // TODO: make more efficent 00278 for(std::vector<LIBSVM::Result>::iterator it = testResults.begin() ; 00279 it != testResults.end() ; it ++){ 00280 00281 added = false; 00282 for(std::vector<LIBSVM::Result>::iterator ita = AllResults.begin() ; 00283 ita != AllResults.end() ; ita ++){ 00284 00285 if ( (*it).GeneName.compare((*ita).GeneName) == 0 ){ 00286 00287 (*ita).Value += (*it).Value; 00288 added = true; 00289 break; 00290 } 00291 00292 } 00293 00294 if(!added) 00295 AllResults.push_back((*it)); 00296 00297 } 00298 testResults.clear(); 00299 00300 // classify genes with no label information 00301 if (sArgs.all_flag) { 00302 tmpUnlabeledResults = SVM.Classify( 00303 PCL, vec_allUnlabeledLabels);//make predictions 00304 for (j = 0; j < tmpUnlabeledResults.size(); j++) 00305 vec_allUnlabeledResults[j].Value 00306 += tmpUnlabeledResults[j].Value; 00307 } 00308 00309 if (i > 0) { 00310 //LIBSVM::CLIBSVM::FreeSample(*pTrainSample); 00311 free(pTrainSample); 00312 } 00313 00314 //mem = CMeta::GetMemoryUsage(); 00315 00316 cerr << "end of trail" << endl; 00317 00318 } 00319 00320 // average results (svm outputs) from multiple cv runs 00321 for(std::vector<LIBSVM::Result>::iterator it = AllResults.begin(); 00322 it != AllResults.end(); ++ it){ 00323 (*it).Value /= sArgs.num_cv_runs_arg; 00324 00325 } 00326 00327 if (sArgs.all_flag) { //add the unlabeled results 00328 for (j = 0; j < vec_allUnlabeledResults.size(); j++) 00329 vec_allUnlabeledResults[j].Value 00330 /= (sArgs.cross_validation_arg * sArgs.num_cv_runs_arg); 00331 AllResults.insert(AllResults.end(), 00332 vec_allUnlabeledResults.begin(), 00333 vec_allUnlabeledResults.end()); 00334 } 00335 00336 ofstream ofsm; 00337 ofsm.clear(); 00338 ofsm.open(sArgs.output_arg); 00339 PrintResults(AllResults, ofsm); 00340 return 0; 00341 00342 } else { 00343 cerr << "More options are needed" << endl; 00344 } 00345 00346 } 00347