// Sleipnir
00001 #include <fstream> 00002 00003 #include <vector> 00004 #include <queue> 00005 00006 /***************************************************************************** 00007 * This file is provided under the Creative Commons Attribution 3.0 license. 00008 * 00009 * You are free to share, copy, distribute, transmit, or adapt this work 00010 * PROVIDED THAT you attribute the work to the authors listed below. 00011 * For more information, please see the following web page: 00012 * http://creativecommons.org/licenses/by/3.0/ 00013 * 00014 * This file is a component of the Sleipnir library for functional genomics, 00015 * authored by: 00016 * Curtis Huttenhower (chuttenh@princeton.edu) 00017 * Mark Schroeder 00018 * Maria D. Chikina 00019 * Olga G. Troyanskaya (ogt@princeton.edu, primary contact) 00020 * 00021 * If you use this library, the included executable tools, or any related 00022 * code in your work, please cite the following publication: 00023 * Curtis Huttenhower, Mark Schroeder, Maria D. Chikina, and 00024 * Olga G. Troyanskaya. 
00025 * "The Sleipnir library for computational functional genomics" 00026 *****************************************************************************/ 00027 #include "stdafx.h" 00028 #include "cmdline.h" 00029 #include "statistics.h" 00030 00031 using namespace SVMLight; 00032 //#include "../../extlib/svm_light/svm_light/kernel.h" 00033 00034 inline bool file_exists (const std::string& name) { 00035 struct stat buffer; 00036 return (stat (name.c_str(), &buffer) == 0); 00037 } 00038 00039 vector< pair< string, string > > ReadLabelList(ifstream & ifsm, string output_prefix) { 00040 static const size_t c_iBuffer = 1024; 00041 char acBuffer[c_iBuffer]; 00042 vector<string> vecstrTokens; 00043 vector< pair < string, string > > inout; 00044 while (!ifsm.eof()) { 00045 ifsm.getline(acBuffer, c_iBuffer - 1); 00046 acBuffer[c_iBuffer - 1] = 0; 00047 vecstrTokens.clear(); 00048 CMeta::Tokenize(acBuffer, vecstrTokens); 00049 if (vecstrTokens.empty()) 00050 continue; 00051 if (vecstrTokens.size() != 2) { 00052 cerr << "Illegal inout line (" << vecstrTokens.size() << "): " 00053 << acBuffer << endl; 00054 continue; 00055 } 00056 00057 if( file_exists( output_prefix + "/" + vecstrTokens[1] ) ){ 00058 continue; 00059 } 00060 00061 00062 //cout << file_exists( vecstrTokens[1] ) << endl; 00063 00064 inout.push_back( make_pair( vecstrTokens[0], vecstrTokens[1] ) ); 00065 } 00066 cout << inout.size() << " number of label files." 
<< endl; 00067 return inout; 00068 00069 } 00070 00071 vector<SVMLight::SVMLabel> ReadLabels(ifstream & ifsm) { 00072 00073 static const size_t c_iBuffer = 1024; 00074 char acBuffer[c_iBuffer]; 00075 vector<string> vecstrTokens; 00076 vector<SVMLight::SVMLabel> vecLabels; 00077 size_t numPositives, numNegatives; 00078 numPositives = numNegatives = 0; 00079 while (!ifsm.eof()) { 00080 ifsm.getline(acBuffer, c_iBuffer - 1); 00081 acBuffer[c_iBuffer - 1] = 0; 00082 vecstrTokens.clear(); 00083 CMeta::Tokenize(acBuffer, vecstrTokens); 00084 if (vecstrTokens.empty()) 00085 continue; 00086 if (vecstrTokens.size() != 2) { 00087 cerr << "Illegal label line (" << vecstrTokens.size() << "): " 00088 << acBuffer << endl; 00089 continue; 00090 } 00091 //cout << vecstrTokens[0] << endl; 00092 //cout << vecstrTokens[1] << endl; 00093 00094 00095 vecLabels.push_back(SVMLight::SVMLabel(vecstrTokens[0], atof( 00096 vecstrTokens[1].c_str()))); 00097 if (vecLabels.back().Target > 0) 00098 numPositives++; 00099 else 00100 numNegatives++; 00101 } 00102 00103 cout << numPositives << endl; 00104 cout << numNegatives << endl; 00105 00106 return vecLabels; 00107 } 00108 00109 struct SortResults { 00110 00111 bool operator()(const SVMLight::Result& rOne, const SVMLight::Result & rTwo) const { 00112 return (rOne.Value > rTwo.Value); 00113 } 00114 }; 00115 00116 size_t PrintResults(vector<SVMLight::Result> vecResults, ofstream & ofsm) { 00117 sort(vecResults.begin(), vecResults.end(), SortResults()); 00118 int LabelVal; 00119 for (size_t i = 0; i < vecResults.size(); i++) { 00120 ofsm << vecResults[i].GeneName << '\t' << vecResults[i].Target << '\t' 00121 << vecResults[i].Value << endl; 00122 } 00123 } 00124 ; 00125 00126 struct ParamStruct { 00127 vector<float> vecK, vecTradeoff; 00128 vector<size_t> vecLoss; 00129 vector<char*> vecNames; 00130 }; 00131 00132 ParamStruct ReadParamsFromFile(ifstream& ifsm, string outFile) { 00133 static const size_t c_iBuffer = 1024; 00134 char 
acBuffer[c_iBuffer]; 00135 char* nameBuffer; 00136 vector<string> vecstrTokens; 00137 size_t extPlace; 00138 string Ext, FileName; 00139 if ((extPlace = outFile.find_first_of(".")) != string::npos) { 00140 FileName = outFile.substr(0, extPlace); 00141 Ext = outFile.substr(extPlace, outFile.size()); 00142 } else { 00143 FileName = outFile; 00144 Ext = ""; 00145 } 00146 ParamStruct PStruct; 00147 size_t index = 0; 00148 while (!ifsm.eof()) { 00149 ifsm.getline(acBuffer, c_iBuffer - 1); 00150 acBuffer[c_iBuffer - 1] = 0; 00151 vecstrTokens.clear(); 00152 CMeta::Tokenize(acBuffer, vecstrTokens); 00153 if (vecstrTokens.empty()) 00154 continue; 00155 if (vecstrTokens.size() != 3) { 00156 cerr << "Illegal params line (" << vecstrTokens.size() << "): " 00157 << acBuffer << endl; 00158 continue; 00159 } 00160 if (acBuffer[0] == '#') { 00161 cerr << "skipping " << acBuffer << endl; 00162 } else { 00163 PStruct.vecLoss.push_back(atoi(vecstrTokens[0].c_str())); 00164 PStruct.vecTradeoff.push_back(atof(vecstrTokens[1].c_str())); 00165 PStruct.vecK.push_back(atof(vecstrTokens[2].c_str())); 00166 PStruct.vecNames.push_back(new char[c_iBuffer]); 00167 if (PStruct.vecLoss[index] == 4 || PStruct.vecLoss[index] == 5) 00168 sprintf(PStruct.vecNames[index], "%s_l%d_c%4.6f_k%4.3f%s", 00169 FileName.c_str(), PStruct.vecLoss[index], 00170 PStruct.vecTradeoff[index], PStruct.vecK[index], 00171 Ext.c_str()); 00172 else 00173 sprintf(PStruct.vecNames[index], "%s_l%d_c%4.6f%s", 00174 FileName.c_str(), PStruct.vecLoss[index], 00175 PStruct.vecTradeoff[index], Ext.c_str()); 00176 index++; 00177 } 00178 00179 } 00180 return PStruct; 00181 } 00182 00183 int main(int iArgs, char** aszArgs) { 00184 gengetopt_args_info sArgs; 00185 00186 CPCL PCL; 00187 SVMLight::CSVMPERF SVM; 00188 00189 size_t i, j, iGene, jGene; 00190 ifstream ifsm, iifsm; 00191 00192 if (cmdline_parser(iArgs, aszArgs, &sArgs)) { 00193 cmdline_parser_print_help(); 00194 return 1; 00195 } 00196 
SVM.SetVerbosity(sArgs.verbosity_arg); 00197 SVM.SetLossFunction(sArgs.error_function_arg); 00198 if (sArgs.k_value_arg > 1) { 00199 cerr << "k_value is >1. Setting default 0.5" << endl; 00200 SVM.SetPrecisionFraction(0.5); 00201 } else if (sArgs.k_value_arg <= 0) { 00202 cerr << "k_value is <=0. Setting default 0.5" << endl; 00203 SVM.SetPrecisionFraction(0.5); 00204 } else { 00205 SVM.SetPrecisionFraction(sArgs.k_value_arg); 00206 } 00207 00208 00209 if (sArgs.cross_validation_arg < 1){ 00210 cerr << "cross_valid is <1. Must be set at least 1" << endl; 00211 return 1; 00212 } 00213 else if(sArgs.cross_validation_arg < 2){ 00214 cerr << "cross_valid is set to 1. No cross validation holdouts will be run." << endl; 00215 } 00216 00217 SVM.SetTradeoff(sArgs.tradeoff_arg); 00218 if (sArgs.slack_flag) 00219 SVM.UseSlackRescaling(); 00220 else 00221 SVM.UseMarginRescaling(); 00222 00223 00224 if (!SVM.parms_check()) { 00225 cerr << "Sanity check failed, see above errors" << endl; 00226 return 1; 00227 } 00228 00229 if (!sArgs.output_given){ 00230 cerr << "output prefix not provided" << endl; 00231 return 1; 00232 } 00233 00234 string output_prefix(sArgs.output_arg); 00235 00236 // cout << "there are " << vecLabels.size() << " labels processed" << endl; 00237 size_t iFile; 00238 vector<string> PCLs; 00239 if (sArgs.input_given) { 00240 if (!PCL.Open(sArgs.input_arg, sArgs.skip_arg, sArgs.mmap_flag)) { 00241 cerr << "Could not open input PCL" << endl; 00242 return 1; 00243 } 00244 } 00245 00246 00247 vector< pair < string, string > > vecLabelLists; 00248 if (sArgs.labels_given) { 00249 ifsm.clear(); 00250 ifsm.open(sArgs.labels_arg); 00251 if (ifsm.is_open()) 00252 vecLabelLists = ReadLabelList(ifsm, output_prefix); 00253 else { 00254 cerr << "Could not read label list" << endl; 00255 return 1; 00256 } 00257 ifsm.close(); 00258 }else{ 00259 cerr << "list of labels not given" << endl; 00260 return 1; 00261 // if (sArgs.labels_given) { 00262 // 
vecLabelLists.push_back(pair(sArgs.labels_arg,sArgs.output_arg)) 00263 // } 00264 } 00265 size_t k; 00266 string labels_fn; 00267 string output_fn; 00268 00269 00270 SVMLight::SAMPLE* pTrainSample; 00271 vector<SVMLight::Result> AllResults; 00272 vector<SVMLight::Result> tmpAllResults; 00273 vector<SVMLight::SVMLabel> pTrainVector[sArgs.cross_validation_arg]; 00274 vector<SVMLight::SVMLabel> pTestVector[sArgs.cross_validation_arg]; 00275 vector<SVMLight::SVMLabel> vecLabels; 00276 00277 string out_fn; 00278 00279 for(k = 0; k < vecLabelLists.size(); k ++){ 00280 labels_fn = vecLabelLists[k].first; 00281 output_fn = vecLabelLists[k].second; 00282 00283 cout << labels_fn << endl; 00284 cout << output_fn << endl; 00285 00286 vecLabels.clear(); 00287 00288 ifsm.clear(); 00289 ifsm.open(labels_fn.c_str()); 00290 if (ifsm.is_open()) 00291 vecLabels = ReadLabels(ifsm); 00292 else { 00293 cerr << "Could not read label file" << endl; 00294 return 1; 00295 } 00296 ifsm.close(); 00297 00298 cout << "finished reading labels." 
<< endl; 00299 00300 00301 //do learning and classifying with cross validation 00302 if( sArgs.cross_validation_arg > 1){ 00303 for (i = 0; i < sArgs.cross_validation_arg; i++) { 00304 00305 pTestVector[i].clear(); 00306 pTrainVector[i].clear(); 00307 00308 pTestVector[i].reserve((size_t) vecLabels.size() 00309 / sArgs.cross_validation_arg + sArgs.cross_validation_arg); 00310 pTrainVector[i].reserve((size_t) vecLabels.size() 00311 / (sArgs.cross_validation_arg) 00312 * (sArgs.cross_validation_arg - 1) 00313 + sArgs.cross_validation_arg); 00314 for (j = 0; j < vecLabels.size(); j++) { 00315 if (j % sArgs.cross_validation_arg == i) { 00316 pTestVector[i].push_back(vecLabels[j]); 00317 } else { 00318 pTrainVector[i].push_back((vecLabels[j])); 00319 } 00320 } 00321 } 00322 } 00323 else{ // if you have less than 2 fold cross, no cross validation is done, all train genes are used and predicted 00324 00325 // no holdout so train is the same as test gene set 00326 pTestVector[0].reserve((size_t) vecLabels.size() + sArgs.cross_validation_arg); 00327 pTrainVector[0].reserve((size_t) vecLabels.size() + sArgs.cross_validation_arg); 00328 00329 for (j = 0; j < vecLabels.size(); j++) { 00330 pTestVector[0].push_back(vecLabels[j]); 00331 pTrainVector[0].push_back(vecLabels[j]); 00332 } 00333 } 00334 00335 for (i = 0; i < sArgs.cross_validation_arg; i++) { 00336 pTrainSample = SVMLight::CSVMPERF::CreateSample(PCL, 00337 pTrainVector[i]); 00338 00339 cerr << "Cross Validation Trial " << i << endl; 00340 00341 SVM.Learn(*pTrainSample); 00342 cerr << "Learned" << endl; 00343 tmpAllResults = SVM.Classify(PCL, 00344 pTestVector[i]); 00345 cerr << "Classified " << tmpAllResults.size() << " examples" 00346 << endl; 00347 AllResults.insert(AllResults.end(), tmpAllResults.begin(), 00348 tmpAllResults.end()); 00349 tmpAllResults.resize(0); 00350 00351 if (i > 0) { 00352 SVMLight::CSVMPERF::FreeSample(*pTrainSample); 00353 } 00354 } 00355 00356 ofstream ofsm; 00357 ofsm.clear(); 00358 out_fn 
= output_prefix + "/" + output_fn; 00359 ofsm.open(out_fn.c_str()); 00360 PrintResults(AllResults, ofsm); 00361 cout << "printed: " << output_fn << endl; 00362 00363 00364 delete[] pTrainSample; 00365 AllResults.clear(); 00366 tmpAllResults.clear(); 00367 vecLabels.clear(); 00368 00369 00370 00371 } 00372 } 00373