Sleipnir
|
00001 #include <fstream> 00002 00003 #include <vector> 00004 #include <queue> 00005 00006 /***************************************************************************** 00007 * This file is provided under the Creative Commons Attribution 3.0 license. 00008 * 00009 * You are free to share, copy, distribute, transmit, or adapt this work 00010 * PROVIDED THAT you attribute the work to the authors listed below. 00011 * For more information, please see the following web page: 00012 * http://creativecommons.org/licenses/by/3.0/ 00013 * 00014 * This file is a component of the Sleipnir library for functional genomics, 00015 * authored by: 00016 * Curtis Huttenhower (chuttenh@princeton.edu) 00017 * Mark Schroeder 00018 * Maria D. Chikina 00019 * Olga G. Troyanskaya (ogt@princeton.edu, primary contact) 00020 * 00021 * If you use this library, the included executable tools, or any related 00022 * code in your work, please cite the following publication: 00023 * Curtis Huttenhower, Mark Schroeder, Maria D. Chikina, and 00024 * Olga G. Troyanskaya. 00025 * "The Sleipnir library for computational functional genomics" 00026 *****************************************************************************/ 00027 #include "stdafx.h" 00028 #include "cmdline.h" 00029 #include "statistics.h" 00030 00031 using namespace SVMLight; 00032 00033 struct ParamStruct { 00034 vector<float> vecK, vecTradeoff; 00035 vector<size_t> vecLoss; 00036 vector<char*> vecNames; 00037 }; 00038 00039 ParamStruct ReadParamsFromFile(ifstream& ifsm, string outFile) { 00040 static const size_t c_iBuffer = 1024; 00041 char acBuffer[c_iBuffer]; 00042 char* nameBuffer; 00043 vector<string> vecstrTokens; 00044 size_t extPlace; 00045 string Ext, FileName; 00046 if ((extPlace = outFile.find_first_of(".")) != string::npos) { 00047 FileName = outFile.substr(0, extPlace); 00048 Ext = outFile.substr(extPlace, outFile.size()); 00049 } else { 00050 FileName = outFile; 00051 Ext = ""; 00052 } 00053 ParamStruct PStruct; 00054 size_t index = 0; 00055 while (!ifsm.eof()) { 00056 ifsm.getline(acBuffer, c_iBuffer - 1); 00057 acBuffer[c_iBuffer - 1] = 0; 00058 vecstrTokens.clear(); 00059 CMeta::Tokenize(acBuffer, vecstrTokens); 00060 if (vecstrTokens.empty()) 00061 continue; 00062 if (vecstrTokens.size() != 3) { 00063 cerr << "Illegal params line (" << vecstrTokens.size() << "): " 00064 << acBuffer << endl; 00065 continue; 00066 } 00067 if (acBuffer[0] == '#') { 00068 cerr << "skipping " << acBuffer << endl; 00069 } else { 00070 PStruct.vecLoss.push_back(atoi(vecstrTokens[0].c_str())); 00071 PStruct.vecTradeoff.push_back(atof(vecstrTokens[1].c_str())); 00072 PStruct.vecK.push_back(atof(vecstrTokens[2].c_str())); 00073 PStruct.vecNames.push_back(new char[c_iBuffer]); 00074 if (PStruct.vecLoss[index] == 4 || PStruct.vecLoss[index] == 5) 00075 sprintf(PStruct.vecNames[index], "%s_l%d_c%4.6f_k%4.3f%s", 00076 FileName.c_str(), PStruct.vecLoss[index], 00077 PStruct.vecTradeoff[index], PStruct.vecK[index], 00078 Ext.c_str()); 00079 else 00080 sprintf(PStruct.vecNames[index], "%s_l%d_c%4.6f%s", 00081 FileName.c_str(), PStruct.vecLoss[index], 00082 PStruct.vecTradeoff[index], Ext.c_str()); 00083 index++; 00084 } 00085 00086 } 00087 return PStruct; 00088 } 00089 00090 bool ReadModelFile(ifstream& ifsm, vector<float>& SModel) { 00091 static const size_t c_iBuffer = 1024; 00092 char acBuffer[c_iBuffer]; 00093 char* nameBuffer; 00094 vector<string> vecstrTokens; 00095 size_t extPlace; 00096 string Ext, FileName; 00097 size_t index = 0; 00098 00099 while (!ifsm.eof()) { 00100 ifsm.getline(acBuffer, c_iBuffer - 1); 00101 acBuffer[c_iBuffer - 1] = 0; 00102 vecstrTokens.clear(); 00103 CMeta::Tokenize(acBuffer, vecstrTokens); 00104 if (vecstrTokens.empty()) 00105 continue; 00106 if (vecstrTokens.size() > 1) { 00107 cerr << "Illegal model line (" << vecstrTokens.size() << "): " 00108 << acBuffer << endl; 00109 continue; 00110 } 00111 if (acBuffer[0] == '#') { 00112 cerr << "skipping " << acBuffer << endl; 00113 } else { 00114 SModel.push_back(atof(vecstrTokens[0].c_str())); 00115 } 00116 00117 00118 } 00119 ifsm.close(); 00120 } 00121 00122 // Read in the 00123 bool ReadProbParamFile(char* prob_file, float& A, float& B) { 00124 static const size_t c_iBuffer = 1024; 00125 char acBuffer[c_iBuffer]; 00126 char* nameBuffer; 00127 vector<string> vecstrTokens; 00128 size_t i, extPlace; 00129 string Ext, FileName; 00130 size_t index = 0; 00131 ifstream ifsm; 00132 00133 ifsm.open( prob_file ); 00134 i = 0; 00135 while (!ifsm.eof()) { 00136 ifsm.getline(acBuffer, c_iBuffer - 1); 00137 acBuffer[c_iBuffer - 1] = 0; 00138 vecstrTokens.clear(); 00139 CMeta::Tokenize(acBuffer, vecstrTokens); 00140 if (vecstrTokens.empty()) 00141 continue; 00142 if (vecstrTokens.size() > 1) { 00143 cerr << "Illegal model line (" << vecstrTokens.size() << "): " 00144 << acBuffer << endl; 00145 continue; 00146 } 00147 if (acBuffer[0] == '#') { 00148 cerr << "skipping " << acBuffer << endl; 00149 } else { 00150 if( i == 0 ) 00151 A = atof(vecstrTokens[0].c_str()); 00152 else if( i == 1 ) 00153 B = atof(vecstrTokens[0].c_str()); 00154 else{ 00155 cerr << "" << endl; 00156 return false; 00157 } 00158 i++; 00159 } 00160 } 00161 cerr << "Reading Prob file, A: " << A << ", B: " << B << endl; 00162 return true; 00163 } 00164 00165 // Platt's binary SVM Probablistic Output 00166 // Assume dec_values and labels have same dimensions and genes 00167 static void sigmoid_train(CDat& dec_values, 00168 CDat& labels, 00169 float& A, float& B){ 00170 double prior1=0, prior0 = 0; 00171 size_t i, j, idx, k; 00172 float d, lab; 00173 00174 int max_iter=100; // Maximal number of iterations 00175 double min_step=1e-10; // Minimal step taken in line search 00176 double sigma=1e-12; // For numerically strict PD of Hessian 00177 double eps=1e-5; 00178 vector<double> t; 00179 double fApB,p,q,h11,h22,h21,g1,g2,det,dA,dB,gd,stepsize; 00180 double newA,newB,newf,d1,d2; 00181 int iter; 00182 00183 // Negatives are values less than 0 00184 for(i = 0; i < dec_values.GetGenes(); i++) 00185 for(j = (i+1); j < dec_values.GetGenes(); j++) 00186 if (!CMeta::IsNaN(d = dec_values.Get(i, j)) && !CMeta::IsNaN(lab = labels.Get(i, j)) ){ 00187 if(lab > 0) 00188 prior1 += 1; 00189 else if(lab < 0) 00190 prior0 += 1; 00191 } 00192 00193 // initialize size 00194 t.resize(prior0+prior1); 00195 00196 // Initial Point and Initial Fun Value 00197 A=0.0; B=log((prior0+1.0)/(prior1+1.0)); 00198 double hiTarget=(prior1+1.0)/(prior1+2.0); 00199 double loTarget=1/(prior0+2.0); 00200 double fval = 0.0; 00201 00202 for(idx = i = 0; idx < dec_values.GetGenes(); idx++) 00203 for(j = (idx+1); j < dec_values.GetGenes(); j++) 00204 if (!CMeta::IsNaN(d = dec_values.Get(idx, j)) && !CMeta::IsNaN(lab = labels.Get(idx, j)) ){ 00205 if (lab > 0 ) t[i]=hiTarget; 00206 else t[i]=loTarget; 00207 00208 fApB = d*A+B; 00209 if (fApB>=0) 00210 fval += t[i]*fApB + log(1+exp(-fApB)); 00211 else 00212 fval += (t[i] - 1)*fApB +log(1+exp(fApB)); 00213 ++i; 00214 } 00215 00216 for (iter=0;iter<max_iter;iter++){ 00217 // Update Gradient and Hessian (use H' = H + sigma I) 00218 h11=sigma; // numerically ensures strict PD 00219 h22=sigma; 00220 h21=0.0;g1=0.0;g2=0.0; 00221 00222 for(i = idx = 0; idx < dec_values.GetGenes(); idx++) 00223 for(j = (idx+1); j < dec_values.GetGenes(); j++) 00224 if (!CMeta::IsNaN(d = dec_values.Get(idx, j)) && !CMeta::IsNaN(lab = labels.Get(idx, j)) ){ 00225 fApB = d*A+B; 00226 00227 if (fApB >= 0){ 00228 p=exp(-fApB)/(1.0+exp(-fApB)); 00229 q=1.0/(1.0+exp(-fApB)); 00230 } 00231 else{ 00232 p=1.0/(1.0+exp(fApB)); 00233 q=exp(fApB)/(1.0+exp(fApB)); 00234 } 00235 d2=p*q; 00236 h11+=d*d*d2; 00237 h22+=d2; 00238 h21+=d*d2; 00239 d1=t[i]-p; 00240 g1+=d*d1; 00241 g2+=d1; 00242 00243 ++i; 00244 } 00245 00246 // Stopping Criteria 00247 if (fabs(g1)<eps && fabs(g2)<eps) 00248 break; 00249 00250 // Finding Newton direction: -inv(H') * g 00251 det=h11*h22-h21*h21; 00252 dA=-(h22*g1 - h21 * g2) / det; 00253 dB=-(-h21*g1+ h11 * g2) / det; 00254 gd=g1*dA+g2*dB; 00255 00256 stepsize = 1; // Line Search 00257 while (stepsize >= min_step){ 00258 newA = A + stepsize * dA; 00259 newB = B + stepsize * dB; 00260 00261 // New function value 00262 newf = 0.0; 00263 00264 for(i = idx = 0; idx < dec_values.GetGenes(); idx++) 00265 for(j = (idx+1); j < dec_values.GetGenes(); j++) 00266 if (!CMeta::IsNaN(d = dec_values.Get(idx, j)) && !CMeta::IsNaN(lab = labels.Get(idx, j)) ){ 00267 fApB = d*newA+newB; 00268 00269 if (fApB >= 0) 00270 newf += t[i]*fApB + log(1+exp(-fApB)); 00271 else 00272 newf += (t[i] - 1)*fApB +log(1+exp(fApB)); 00273 00274 ++i; 00275 } 00276 00277 // Check sufficient decrease 00278 if (newf<fval+0.0001*stepsize*gd){ 00279 A=newA;B=newB;fval=newf; 00280 break; 00281 } 00282 else 00283 stepsize = stepsize / 2.0; 00284 } 00285 00286 if (stepsize < min_step){ 00287 cerr << "Line search fails in two-class probability estimates: " << stepsize << ',' << min_step << endl; 00288 break; 00289 } 00290 } 00291 00292 if (iter>=max_iter) 00293 cerr << "Reaching maximal iterations in two-class probability estimates" << endl; 00294 } 00295 00296 static void sigmoid_predict(CDat& dec_values, float A, float B){ 00297 size_t i, j; 00298 float d, fApB; 00299 00300 for(i = 0; i < dec_values.GetGenes(); i++) 00301 for(j = (i+1); j < dec_values.GetGenes(); j++) 00302 if (!CMeta::IsNaN(d = dec_values.Get(i, j))){ 00303 fApB = d*A+B; 00304 // 1-p used later; avoid catastrophic cancellation 00305 if (fApB >= 0) 00306 dec_values.Set(i,j, exp(-fApB)/(1.0+exp(-fApB))); 00307 else 00308 dec_values.Set(i,j, 1.0/(1+exp(fApB))); 00309 } 00310 } 00311 00312 00313 int main(int iArgs, char** aszArgs) { 00314 gengetopt_args_info sArgs; 00315 SVMLight::CSVMPERF SVM; 00316 00317 size_t i, j, iGene, jGene, numpos, numneg, iSVM; 00318 ifstream ifsm; 00319 float d, sample_rate, dval; 00320 int iRet; 00321 map<string, size_t> mapstriZeros, mapstriDatasets; 00322 vector<string> vecstrDatasets; 00323 vector<bool> mapTgene; 00324 vector<bool> mapCgene; 00325 vector<size_t> mapTgene2fold; 00326 vector<int> tgeneCount; 00327 00328 DIR* dp; 00329 struct dirent* ep; 00330 CGenome Genome; 00331 CGenes Genes(Genome); 00332 00333 CGenome GenomeTwo; 00334 CGenes Context(GenomeTwo); 00335 00336 CGenome GenomeThree; 00337 CGenes Allgenes(GenomeThree); 00338 00339 if (cmdline_parser(iArgs, aszArgs, &sArgs)) { 00340 cmdline_parser_print_help(); 00341 return 1; 00342 } 00343 00344 CMeta Meta( sArgs.verbosity_arg, sArgs.random_arg ); 00345 00346 SVM.SetVerbosity(sArgs.verbosity_arg); 00347 SVM.SetLossFunction(sArgs.error_function_arg); 00348 if (sArgs.k_value_arg > 1) { 00349 cerr << "k_value is >1. Setting default 0.5" << endl; 00350 SVM.SetPrecisionFraction(0.5); 00351 } else if (sArgs.k_value_arg <= 0) { 00352 cerr << "k_value is <=0. Setting default 0.5" << endl; 00353 SVM.SetPrecisionFraction(0.5); 00354 } else { 00355 SVM.SetPrecisionFraction(sArgs.k_value_arg); 00356 } 00357 00358 00359 if (sArgs.cross_validation_arg < 1){ 00360 cerr << "cross_valid is <1. Must be set at least 1" << endl; 00361 return 1; 00362 } 00363 else if(sArgs.cross_validation_arg < 2){ 00364 cerr << "cross_valid is set to 1. No cross validation holdouts will be run." << endl; 00365 } 00366 00367 SVM.SetTradeoff(sArgs.tradeoff_arg); 00368 if (sArgs.slack_flag) 00369 SVM.UseSlackRescaling(); 00370 else 00371 SVM.UseMarginRescaling(); 00372 00373 00374 if (!SVM.parms_check()) { 00375 cerr << "Sanity check failed, see above errors" << endl; 00376 return 1; 00377 } 00378 00379 // read in the list of datasets 00380 if(sArgs.directory_arg ) { 00381 dp = opendir (sArgs.directory_arg); 00382 if (dp != NULL){ 00383 while (ep = readdir (dp)){ 00384 // skip . .. files and temp files with ~ 00385 if (ep->d_name[0] == '.' || ep->d_name[strlen(ep->d_name)-1] == '~') 00386 continue; 00387 00388 // currently opens all files. Add filter here if want pick file extensions 00389 vecstrDatasets.push_back((string)sArgs.directory_arg + "/" + ep->d_name); 00390 } 00391 (void) closedir (dp); 00392 00393 cerr << "Input Dir contrains # datasets: " << vecstrDatasets.size() << '\n'; 00394 // sort datasets in alphabetical order 00395 std::sort(vecstrDatasets.begin(), vecstrDatasets.end()); 00396 } 00397 else{ 00398 cerr << "Couldn't open the directory: " << sArgs.directory_arg << '\n'; 00399 return 1; 00400 } 00401 } 00402 00403 // read target gene list 00404 if(sArgs.tgene_given ) { 00405 ifstream ifsm; 00406 ifsm.open(sArgs.tgene_arg); 00407 00408 if (!Genes.Open(ifsm)) { 00409 cerr << "Could not open: " << sArgs.tgene_arg << endl; 00410 return 1; 00411 } 00412 ifsm.close(); 00413 } 00414 00415 // read context gene list 00416 if(sArgs.context_given ) { 00417 ifstream ifsm; 00418 ifsm.open(sArgs.context_arg); 00419 00420 if (!Context.Open(ifsm)) { 00421 cerr << "Could not open: " << sArgs.context_arg << endl; 00422 return 1; 00423 } 00424 ifsm.close(); 00425 } 00426 00427 // read all gene list 00428 // IF given this flag predict for all gene pairs 00429 if(sArgs.allgenes_given ) { 00430 ifstream ifsm; 00431 ifsm.open(sArgs.allgenes_arg); 00432 00433 if (!Allgenes.Open(ifsm)) { 00434 cerr << "Could not open: " << sArgs.allgenes_arg << endl; 00435 return 1; 00436 } 00437 ifsm.close(); 00438 } 00439 00441 // Chris added 00442 vector<SVMLight::SVMLabelPair*> vecLabels; 00443 CDat Labels; 00444 CDat Results; 00445 00446 00450 if ( sArgs.labels_given ) { 00451 if (!Labels.Open(sArgs.labels_arg, sArgs.mmap_flag)) { 00452 cerr << "Could not open input labels Dat" << endl; 00453 return 1; 00454 } 00455 00456 // random sample labels 00457 if( sArgs.subsample_given ){ 00458 cerr << "Sub-sample labels to rate:" << sArgs.subsample_arg << endl; 00459 for( i = 0; i < Labels.GetGenes( ); ++i ) 00460 for( j = ( i + 1 ); j < Labels.GetGenes( ); ++j ) 00461 if( !CMeta::IsNaN( Labels.Get( i, j ) ) && 00462 ( ( (float)rand( ) / RAND_MAX ) > sArgs.subsample_arg ) ) 00463 Labels.Set( i, j, CMeta::GetNaN( ) ); 00464 } 00465 00466 // set all NaN values to negatives 00467 if( sArgs.nan2neg_given ){ 00468 cerr << "Set NaN labels dat as negatives" << endl; 00469 00470 for(i = 0; i < Labels.GetGenes(); i++) 00471 for(j = (i+1); j < Labels.GetGenes(); j++) 00472 if (CMeta::IsNaN(d = Labels.Get(i, j))) 00473 Labels.Set(i, j, -1); 00474 } 00475 00476 if( sArgs.tgene_given ){ 00477 mapTgene.resize(Labels.GetGenes()); 00478 00479 for(i = 0; i < Labels.GetGenes(); i++){ 00480 if(Genes.GetGene(Labels.GetGene(i)) == -1) 00481 mapTgene[i] = false; 00482 else 00483 mapTgene[i] = true; 00484 } 00485 00486 // keep track of positive gene counts 00487 tgeneCount.resize(Labels.GetGenes()); 00488 00489 // if given a target gene file 00490 // Only keep eges that have only one gene in this targe gene list 00491 if( sArgs.onetgene_flag ){ 00492 cerr << "Filtering to only include edges with one gene in gene file: " << sArgs.tgene_arg << endl; 00493 00494 for(i = 0; i < Labels.GetGenes(); i++) 00495 for(j = (i+1); j < Labels.GetGenes(); j++) 00496 if (!CMeta::IsNaN(d = Labels.Get(i, j))){ 00497 if(mapTgene[i] && mapTgene[j]) 00498 Labels.Set(i, j, CMeta::GetNaN()); 00499 else if(!mapTgene[i] && !mapTgene[j]) 00500 Labels.Set(i, j, CMeta::GetNaN()); 00501 } 00502 } 00503 } // if edgeholdout flag not given, we are doing gene holdout by default. 00504 // Since target gene list was not given we are using all genes in labels as target genes to cross holdout 00505 else if( !sArgs.edgeholdout_flag ){ 00506 mapTgene.resize(Labels.GetGenes()); 00507 00508 // all genes are target genes 00509 for(i = 0; i < Labels.GetGenes(); i++){ 00510 mapTgene[i] = true; 00511 } 00512 00513 // keep track of positive gene counts 00514 tgeneCount.resize(Labels.GetGenes()); 00515 } 00516 00517 //if given a context map the context genes 00518 if( sArgs.context_given ){ 00519 mapCgene.resize(Labels.GetGenes()); 00520 00521 for(i = 0; i < Labels.GetGenes(); i++){ 00522 if(Context.GetGene(Labels.GetGene(i)) == -1) 00523 mapCgene[i] = false; 00524 else 00525 mapCgene[i] = true; 00526 } 00527 } 00528 00529 // Set target prior 00530 if(sArgs.prior_given){ 00531 numpos = 0; 00532 numneg = 0; 00533 for(i = 0; i < Labels.GetGenes(); i++) 00534 for(j = (i+1); j < Labels.GetGenes(); j++) 00535 if (!CMeta::IsNaN(d = Labels.Get(i, j))){ 00536 if(d > 0){ 00537 ++numpos;} 00538 else if(d < 0){ 00539 ++numneg; 00540 } 00541 } 00542 00543 if( ((float)numpos / (numpos + numneg)) < sArgs.prior_arg){ 00544 00545 cerr << "Convert prior from orig: " << ((float)numpos / (numpos + numneg)) << " to target: " << sArgs.prior_arg << endl; 00546 00547 sample_rate = ((float)numpos / (numpos + numneg)) / sArgs.prior_arg; 00548 00549 // remove neg labels to reach prior 00550 for(i = 0; i < Labels.GetGenes(); i++) 00551 for(j = (i+1); j < Labels.GetGenes(); j++) 00552 if (!CMeta::IsNaN(d = Labels.Get(i, j)) && d < 0){ 00553 if((float)rand() / RAND_MAX > sample_rate) 00554 Labels.Set(i, j, CMeta::GetNaN()); 00555 } 00556 } 00557 } 00558 00559 // output sample labels for eval/debug purpose 00560 if(sArgs.OutLabels_given){ 00561 cerr << "save sampled labels as (1,0)s to: " << sArgs.OutLabels_arg << endl; 00562 Labels.Normalize( CDat::ENormalizeMinMax ); 00563 Labels.Save(sArgs.OutLabels_arg); 00564 return 0; 00565 } 00566 00567 // Exclude labels without context genes 00568 if(sArgs.context_given ) 00569 Labels.FilterGenes( Context, CDat::EFilterInclude ); 00570 00571 00572 // If not given a SVM model/models we are in learning mode, thus construct each SVMLabel object for label 00573 if( !sArgs.model_given && !sArgs.modelPrefix_given ){ 00574 numpos = 0; 00575 for(i = 0; i < Labels.GetGenes(); i++) 00576 for(j = (i+1); j < Labels.GetGenes(); j++) 00577 if (!CMeta::IsNaN(d = Labels.Get(i, j))){ 00578 if (d != 0) 00579 vecLabels.push_back(new SVMLight::SVMLabelPair(d, i, j)); 00580 if(d > 0) 00581 ++numpos; 00582 } 00583 00584 // check to see if you have enough positives to learn from 00585 if(sArgs.mintrain_given && sArgs.mintrain_arg > numpos){ 00586 cerr << "Not enough positive examples from: " << sArgs.labels_arg << " numpos: " << numpos << endl; 00587 return 1; 00588 } 00589 } 00590 00591 // save sampled labels 00592 if(sArgs.SampledLabels_given) { 00593 cerr << "Save sampled labels file: " << sArgs.SampledLabels_arg << endl; 00594 Labels.Save(sArgs.SampledLabels_arg); 00595 } 00596 } 00597 00598 SVMLight::SAMPLE* pTrainSample; 00599 SVMLight::SAMPLE* pAllSample; 00600 vector<SVMLight::SVMLabelPair*> pTrainVector[sArgs.cross_validation_arg]; 00601 vector<SVMLight::SVMLabelPair*> pTestVector[sArgs.cross_validation_arg]; 00602 00603 // ################################### 00604 // 00605 // Now conduct Learning if given a labels and no SVM models 00606 // 00607 if (sArgs.output_given && sArgs.labels_given && !sArgs.modelPrefix_given && !sArgs.model_given) { 00608 //do learning and classifying with cross validation 00609 if( sArgs.cross_validation_arg > 1){ 00610 cerr << "setting cross validation holds" << endl; 00611 00612 mapTgene2fold.resize(mapTgene.size()); 00613 00614 // assign target genes to there cross validation fold 00615 if(sArgs.tgene_given || !sArgs.edgeholdout_flag){ 00616 for(i = 0; i < mapTgene.size(); i++){ 00617 if(!mapTgene[i]){ 00618 mapTgene2fold[i] = -1; 00619 continue; 00620 } 00621 //cerr << "what's up?" << endl; 00622 mapTgene2fold[i] = rand() % sArgs.cross_validation_arg; 00623 } 00624 00625 // cross-fold by target gene 00626 for (i = 0; i < sArgs.cross_validation_arg; i++) { 00627 cerr << "cross validation holds setup:" << i << endl; 00628 00629 // keep track of positive gene counts 00630 if(sArgs.balance_flag){ 00631 cerr << "Set up balance: " << i << endl; 00632 for(j = 0; j < Labels.GetGenes(); j++) 00633 tgeneCount[j] = 0; 00634 00635 for(j = 0; j < vecLabels.size(); j++) 00636 if(vecLabels[j]->Target > 0){ 00637 ++(tgeneCount[vecLabels[j]->iidx]); 00638 ++(tgeneCount[vecLabels[j]->jidx]); 00639 } 00640 00641 if(sArgs.bfactor_given) 00642 for(j = 0; j < vecLabels.size(); j++) 00643 if(tgeneCount[vecLabels[j]->jidx] < 500) 00644 tgeneCount[vecLabels[j]->jidx] = sArgs.bfactor_arg*tgeneCount[vecLabels[j]->jidx]; 00645 } 00646 00647 for (j = 0; j < vecLabels.size(); j++) { 00648 //if( j % 1000 == 0) 00649 //cerr << "cross validation push labels:" << j << endl; 00650 00651 // assume only one gene is a target gene in a edge 00652 if(mapTgene[vecLabels[j]->iidx]){ 00653 if(vecLabels[j]->Target < 0){ 00654 --(tgeneCount[vecLabels[j]->iidx]); 00655 } 00656 00657 if(mapTgene2fold[vecLabels[j]->iidx] == i) 00658 pTestVector[i].push_back(vecLabels[j]); 00659 else{ 00660 //cerr << tgeneCount[vecLabels[j]->iidx] << endl; 00661 00662 if( sArgs.balance_flag && vecLabels[j]->Target < 0 && tgeneCount[vecLabels[j]->iidx] < 0){ 00663 continue; 00664 } 00665 00666 // only add if both genes are in context 00667 if( sArgs.context_given && ( !mapCgene[vecLabels[j]->iidx] || !mapCgene[vecLabels[j]->jidx])) 00668 continue; 00669 00670 pTrainVector[i].push_back(vecLabels[j]); 00671 } 00672 } 00673 else if(mapTgene[vecLabels[j]->jidx]){ 00674 if(vecLabels[j]->Target < 0) 00675 --(tgeneCount[vecLabels[j]->jidx]); 00676 00677 if(mapTgene2fold[vecLabels[j]->jidx] == i) 00678 pTestVector[i].push_back(vecLabels[j]); 00679 else{ 00680 //cerr << tgeneCount[vecLabels[j]->jidx] << endl; 00681 00682 if( sArgs.balance_flag && vecLabels[j]->Target < 0 && tgeneCount[vecLabels[j]->jidx] < 0){ 00683 continue; 00684 } 00685 00686 // only add if both genes are in context 00687 if( sArgs.context_given && ( !mapCgene[vecLabels[j]->iidx] || !mapCgene[vecLabels[j]->jidx])) 00688 continue; 00689 00690 pTrainVector[i].push_back(vecLabels[j]); 00691 } 00692 } 00693 else{ 00694 cerr << "Error: edge exist without a target gene" << endl; return 1; 00695 } 00696 } 00697 00698 cerr << "test,"<< i <<": " << pTestVector[i].size() << endl; 00699 int numpos = 0; 00700 for(j=0; j < pTrainVector[i].size(); j++) 00701 if(pTrainVector[i][j]->Target > 0) 00702 ++numpos; 00703 00704 if( numpos < 1 || (sArgs.mintrain_given && sArgs.mintrain_arg > numpos) ){ 00705 cerr << "Not enough positive examples from fold: " << i << " file: " << sArgs.labels_arg << " numpos: " << numpos << endl; 00706 return 1; 00707 } 00708 00709 cerr << "train,"<< i <<","<<numpos<<": " << pTrainVector[i].size() << endl; 00710 00711 } 00712 } 00713 else{ //randomly set eges into cross-fold 00714 if( sArgs.context_given ){ 00715 cerr << "context not implemented yet for random edge holdout" << endl; 00716 return 1; 00717 } 00718 00719 for (i = 0; i < sArgs.cross_validation_arg; i++) { 00720 pTestVector[i].reserve((size_t) vecLabels.size() 00721 / sArgs.cross_validation_arg + sArgs.cross_validation_arg); 00722 pTrainVector[i].reserve((size_t) vecLabels.size() 00723 / (sArgs.cross_validation_arg) 00724 * (sArgs.cross_validation_arg - 1) 00725 + sArgs.cross_validation_arg); 00726 for (j = 0; j < vecLabels.size(); j++) { 00727 if (j % sArgs.cross_validation_arg == i) { 00728 pTestVector[i].push_back(vecLabels[j]); 00729 } else { 00730 pTrainVector[i].push_back((vecLabels[j])); 00731 } 00732 } 00733 } 00734 } 00735 } 00736 else{ // if you have less than 2 fold cross, no cross validation is done, all train genes are used and predicted 00737 00738 // no holdout so train is the same as test gene set 00739 pTestVector[0].reserve((size_t) vecLabels.size() + sArgs.cross_validation_arg); 00740 pTrainVector[0].reserve((size_t) vecLabels.size() + sArgs.cross_validation_arg); 00741 00742 for (j = 0; j < vecLabels.size(); j++) { 00743 pTestVector[0].push_back(vecLabels[j]); 00744 pTrainVector[0].push_back(vecLabels[j]); 00745 } 00746 } 00747 00748 // initalize the results 00749 Results.Open(Labels.GetGeneNames(), true); 00750 00752 // Create feature vectors for all Label pairs using input datasets 00753 // 00754 cerr << "CreateDocs!"<< endl; 00755 if(sArgs.normalizeZero_flag){ 00756 SVMLight::CSVMPERF::CreateDoc(vecstrDatasets, 00757 vecLabels, 00758 Labels.GetGeneNames(), 00759 Sleipnir::CDat::ENormalizeMinMax); 00760 }else if(sArgs.normalizeNPone_flag){ 00761 SVMLight::CSVMPERF::CreateDoc(vecstrDatasets, 00762 vecLabels, 00763 Labels.GetGeneNames(), 00764 Sleipnir::CDat::ENormalizeMinMaxNPone); 00765 }else{ 00766 SVMLight::CSVMPERF::CreateDoc(vecstrDatasets, 00767 vecLabels, 00768 Labels.GetGeneNames()); 00769 } 00770 00772 // Start learning for each cross validation fold 00773 for (i = 0; i < sArgs.cross_validation_arg; i++) { 00774 std::stringstream sstm; 00775 00776 // build up the output SVM model file name 00777 if(sArgs.context_given){ 00778 std::string path(sArgs.context_arg); 00779 size_t pos = path.find_last_of("/"); 00780 std::string cname; 00781 if(pos != std::string::npos) 00782 cname.assign(path.begin() + pos + 1, path.end()); 00783 else 00784 cname = path; 00785 00786 sstm << sArgs.output_arg << "." << sArgs.tradeoff_arg << "." << cname << "." << i << ".svm"; 00787 }else 00788 sstm << sArgs.output_arg << "." << sArgs.tradeoff_arg << "." << i << ".svm"; 00789 00790 cerr << "Cross validation fold: " << i << endl; 00791 pTrainSample = SVMLight::CSVMPERF::CreateSample(pTrainVector[i]); 00792 00793 // Skip learning if SVM model file already exist 00794 if( sArgs.skipSVM_flag && 00795 access((char*)(sstm.str().c_str()), R_OK ) != -1 00796 ){ 00797 //SVM.ReadModel((char*)(sstm.str().c_str())); 00798 SVM.ReadModelLinear((char*)(sstm.str().c_str())); 00799 cerr << "Using existing trained SVM model: " << sstm.str() << endl; 00800 }else{ 00801 SVM.Learn(*pTrainSample); 00802 cerr << "SVM model Learned" << endl; 00803 } 00804 00805 SVM.Classify(Results, 00806 pTestVector[i]); 00807 00808 cerr << "SVM model classified holdout" << endl; 00809 00810 if( sArgs.savemodel_flag ){ 00811 SVM.WriteModel((char*)(sstm.str().c_str())); 00812 } 00813 00814 // DEBUG 00815 SVMLight::CSVMPERF::FreeSample_leave_Doc(*pTrainSample); 00816 free(pTrainSample); 00817 } 00818 00819 if( sArgs.prob_flag ){ 00820 cerr << "Converting prediction values to estimated probablity" << endl; 00821 float A, B; 00822 00823 // TODO add function to read in prob parameter file if already existing 00824 sigmoid_train(Results, Labels, A, B); 00825 sigmoid_predict(Results, A, B); 00826 } 00827 else if( sArgs.probCross_flag ){ 00828 float A, B; 00829 size_t k, ctrain, itrain; 00830 vector<SVMLight::SVMLabelPair*> probTrainVector; 00831 00832 for (i = 0; i < sArgs.cross_validation_arg; i++) { 00833 cerr << "Convert to probability for cross fold: " << i << endl; 00834 00835 // construct prob file name 00836 std::stringstream pstm; 00837 ofstream ofsm; 00838 if(sArgs.context_given){ 00839 std::string path(sArgs.context_arg); 00840 size_t pos = path.find_last_of("/"); 00841 std::string cname; 00842 if(pos != std::string::npos) 00843 cname.assign(path.begin() + pos + 1, path.end()); 00844 else 00845 cname = path; 00846 00847 pstm << sArgs.output_arg << "." << sArgs.tradeoff_arg << "." << cname << "." << i << ".svm.prob"; 00848 }else 00849 pstm << sArgs.output_arg << "." << sArgs.tradeoff_arg << "." << i << ".svm.prob"; 00850 00851 00852 if( sArgs.skipSVM_flag && 00853 access((char*)(pstm.str().c_str()), R_OK ) != -1 00854 ){ 00855 00856 // read in parameter file 00857 if(!ReadProbParamFile((char*)(pstm.str().c_str()), A, B)){ 00858 cerr << "Failed to read Probablity parameter file: " << pstm.str() << endl; 00859 return 1; 00860 } 00861 cerr << pstm.str() << ": read in values, A: " << A << ", B: " << B << endl; 00862 00863 }else{ 00864 ctrain = 0; 00865 for (j = 0; j < sArgs.cross_validation_arg; j++) { 00866 if(i == j) 00867 continue; 00868 ctrain += pTrainVector[j].size(); 00869 } 00870 00871 probTrainVector.resize(ctrain); 00872 itrain = 0; 00873 for (j = 0; j < sArgs.cross_validation_arg; j++) { 00874 if(i == j) 00875 continue; 00876 for(k = 0; k < pTrainVector[j].size(); k++){ 00877 probTrainVector[itrain] = pTrainVector[j][k]; 00878 itrain += 1; 00879 } 00880 } 00881 00882 // train A,B sigmoid perameters 00883 SVM.sigmoid_train(Results, probTrainVector, A, B); 00884 } 00885 00886 SVM.sigmoid_predict(Results, pTestVector[i], A, B); 00887 00888 // open prob param file 00889 if(sArgs.savemodel_flag){ 00890 ofsm.open(pstm.str().c_str()); 00891 ofsm << A << endl; 00892 ofsm << B << endl; 00893 ofsm.close(); 00894 } 00895 } 00896 } 00897 00898 // only save cross-validated results when not predicting all genes 00899 if( !sArgs.allgenes_given ) 00900 Results.Save(sArgs.output_arg); 00901 } 00902 00904 // If given all genes arg, this puts in prediction mode for all gene pairs from the gene list 00905 // 00906 if ( sArgs.allgenes_given && sArgs.output_given) { //read model and classify all 00907 size_t iData; 00908 vector<vector<float> > vecSModel; 00909 vecSModel.resize(sArgs.cross_validation_arg); 00910 ifstream mifsm; 00911 vector<CDat* > vecResults; 00912 vector<size_t> veciGene; 00913 00914 cerr << "Predicting for all genes given." << endl; 00915 00916 // open SVM models for prefix file 00917 if(sArgs.modelPrefix_given){ 00918 for(i = 0; i < sArgs.cross_validation_arg; i++){ 00919 std::stringstream sstm; 00920 00921 sstm << sArgs.modelPrefix_arg << "." << i << ".svm"; 00922 if( access((char*)(sstm.str().c_str()), R_OK) == -1 ){ 00923 cerr << "ERROR: SVM model file cannot be opned: " << sstm.str() << endl; 00924 return 1; 00925 } 00926 00927 mifsm.open((char*)(sstm.str().c_str())); 00928 ReadModelFile(mifsm, vecSModel[i]); 00929 } 00930 }else if( sArgs.model_given ){ // open SVM model from file 00931 //vector<float> SModel; 00932 vecSModel.resize(1); 00933 00934 if( access(sArgs.model_arg, R_OK) == -1 ){ 00935 cerr << "ERROR: SVM model file cannot be opned: " << sArgs.model_arg << endl; 00936 return 1; 00937 } 00938 00939 mifsm.open(sArgs.model_arg); 00940 ReadModelFile(mifsm, vecSModel[0]); 00941 00942 // DEBUG check if this is ok 00943 sArgs.cross_validation_arg = 1; 00944 }else{ 00945 // open SVM model file from which was just trained 00946 for(i = 0; i < sArgs.cross_validation_arg; i++){ 00947 std::stringstream sstm; 00948 00949 if(sArgs.context_given){ 00950 std::string path(sArgs.context_arg); 00951 size_t pos = path.find_last_of("/"); 00952 std::string cname; 00953 if(pos != std::string::npos) 00954 cname.assign(path.begin() + pos + 1, path.end()); 00955 else 00956 cname = path; 00957 00958 sstm << sArgs.output_arg << "." << sArgs.tradeoff_arg << "." << cname << "." << i << ".svm"; 00959 }else 00960 sstm << sArgs.output_arg << "." << sArgs.tradeoff_arg << "." << i << ".svm"; 00961 00962 00963 if( access((char*)(sstm.str().c_str()), R_OK) == -1 ){ 00964 cerr << "ERROR: SVM model file cannot be opned: " << sstm.str() << endl; 00965 return 1; 00966 } 00967 00968 mifsm.open((char*)(sstm.str().c_str())); 00969 ReadModelFile(mifsm, vecSModel[i]); 00970 } 00971 } 00972 00973 // Initialize output for input all gene list 00974 vecResults.resize(sArgs.cross_validation_arg); 00975 for(i = 0; i < sArgs.cross_validation_arg; i++){ 00976 vecResults[ i ] = new CDat(); 00977 vecResults[ i ]->Open(Allgenes.GetGeneNames( ), true); 00978 } 00979 00980 00981 // Now iterate over all datasets to make predictions for all gene pairs 00982 CDat wDat; 00983 for(iData = 0; iData < vecstrDatasets.size(); iData++){ 00984 if(!wDat.Open(vecstrDatasets[iData].c_str(), sArgs.mmap_flag)) { 00985 cerr << vecstrDatasets[iData].c_str() << endl; 00986 cerr << "Could not open: " << vecstrDatasets[iData] << endl; 00987 return 1; 00988 } 00989 00990 cerr << "Open: " << vecstrDatasets[iData] << endl; 00991 00992 // normalize data file 00993 if(sArgs.normalizeZero_flag){ 00994 cerr << "Normalize input [0,1] data" << endl; 00995 wDat.Normalize( Sleipnir::CDat::ENormalizeMinMax ); 00996 }else if(sArgs.normalizeNPone_flag){ 00997 cerr << "Normalize input [-1,1] data" << endl; 00998 wDat.Normalize( Sleipnir::CDat::ENormalizeMinMaxNPone ); 00999 } 01000 01001 // map result gene list to dataset gene list 01002 veciGene.resize( vecResults[ 0 ]->GetGenes() ); 01003 for(i = 0; i < vecResults[ 0 ]->GetGenes(); i++){ 01004 veciGene[ i ] = wDat.GetGene( vecResults[ 0 ]->GetGene( i ) ); 01005 } 01006 01007 // compute prediction component for this dataset/SVM model 01008 for(i = 0; i < vecResults[ 0 ]->GetGenes(); i++){ 01009 iGene = veciGene[i]; 01010 01011 for(j = i+1; j < vecResults[ 0 ]->GetGenes(); j++){ 01012 jGene = veciGene[j]; 01013 01014 // if no value, set feature to 0 01015 if( iGene == -1 || jGene == -1 || CMeta::IsNaN(d = wDat.Get(iGene, jGene) ) ) 01016 d = 0; 01017 01018 // iterate each SVM model 01019 for(iSVM = 0; iSVM < sArgs.cross_validation_arg; iSVM++){ 01020 if( CMeta::IsNaN(dval = vecResults[ iSVM ]->Get(i, j)) ) 01021 vecResults[ iSVM ]->Set(i, j, (d * vecSModel[iSVM][iData]) ); 01022 else 01023 vecResults[ iSVM ]->Set(i, j, (dval + (d * vecSModel[iSVM][iData])) ); 01024 01025 } 01026 01027 } 01028 } 01029 } 01030 01031 01032 // convert the SVM predictions for each model to Probablity if required 01033 if( sArgs.prob_flag || sArgs.probCross_flag ){ 01034 // iterate over each SVM model and its output and convert to probablity 01035 float A; 01036 float B; 01037 01038 cerr << "convert prediction dabs to probablity" << endl; 01039 01040 for(iSVM = 0; iSVM < sArgs.cross_validation_arg; iSVM++){ 01041 // Read A,B perameters 01042 std::stringstream pstm; 01043 if(sArgs.modelPrefix_given){ 01044 pstm << sArgs.modelPrefix_arg << "." << iSVM << ".svm.prob"; 01045 }else if( sArgs.model_given ){ // open SVM model from file 01046 01047 // DEBUG, this might need to looked at!!! 01048 pstm << sArgs.model_arg << ".prob"; 01049 01050 }else{ // open SVM model file from which was just trained 01051 if(sArgs.context_given){ 01052 std::string path(sArgs.context_arg); 01053 size_t pos = path.find_last_of("/"); 01054 std::string cname; 01055 if(pos != std::string::npos) 01056 cname.assign(path.begin() + pos + 1, path.end()); 01057 else 01058 cname = path; 01059 01060 pstm << sArgs.output_arg << "." << sArgs.tradeoff_arg << "." << cname << "." << iSVM << ".svm.prob"; 01061 }else 01062 pstm << sArgs.output_arg << "." << sArgs.tradeoff_arg << "." << iSVM << ".svm.prob"; 01063 } 01064 01065 // read in parameter file 01066 if(!ReadProbParamFile((char*)(pstm.str().c_str()), A, B)){ 01067 cerr << "Failed to read Probablity parameter file: " << pstm.str() << endl; 01068 return 1; 01069 } 01070 01071 cerr << pstm.str() << ", A: " << A << ",B: " << B << endl; 01072 01073 // debug 01074 //cerr << "known:" << vecResults[ iSVM ]->Get(x,y) << endl; 01075 01076 // now convert the SVM model prediction values to probablity based on param A,B 01077 sigmoid_predict( *(vecResults[ iSVM ]), A, B); 01078 01079 //cerr << "known PROB:" << vecResults[ iSVM ]->Get(x,y) << endl; 01080 } 01081 } 01082 01083 // filter results 01084 if( sArgs.tgene_given ){ 01085 for(iSVM = 0; iSVM < sArgs.cross_validation_arg; iSVM++){ 01086 vecResults[ iSVM ]->FilterGenes( sArgs.tgene_arg, CDat::EFilterEdge ); 01087 } 01088 } 01089 01090 // Exclude pairs without context genes 01091 if(sArgs.context_given ){ 01092 for(iSVM = 0; iSVM < sArgs.cross_validation_arg; iSVM++){ 01093 vecResults[ iSVM ]->FilterGenes( Context, CDat::EFilterInclude ); 01094 } 01095 } 01096 01097 // Take average of prediction value from each model 01098 // The final results will be stored/overwritten into the first dab (i.e. vecResults[ 0 ]) 01099 for(i = 0; i < vecResults[ 0 ]->GetGenes(); i ++) 01100 for(j = i+1; j < vecResults[ 0 ]->GetGenes(); j++){ 01101 if (CMeta::IsNaN(dval = vecResults[ 0 ]->Get(i, j))) 01102 continue; 01103 01104 // start from the second 01105 // Assume all SVM model prediction dabs have identical NaN locations 01106 if(sArgs.aggregateMax_flag){ // take the maximum prediction value 01107 float maxval = dval; 01108 for(iSVM = 1; iSVM < sArgs.cross_validation_arg; iSVM++) 01109 if( vecResults[ iSVM ]->Get(i, j) > dval ) 01110 dval = vecResults[ iSVM ]->Get(i, j); 01111 01112 vecResults[ 0 ]->Set(i, j, dval); 01113 }else{ // Average over prediction values 01114 for(iSVM = 1; iSVM < sArgs.cross_validation_arg; iSVM++) 01115 dval += vecResults[ iSVM ]->Get(i, j); 01116 01117 vecResults[ 0 ]->Set(i, j, (dval / sArgs.cross_validation_arg) ); 01118 } 01119 } 01120 01121 // Replace gene-pair prediction values with labels from the cross-validation result 01122 // This basically replaces the prediction value with one prediction value with which this label was heldout 01123 // Only do this if cross-validation was conducted or given a replacement dab 01124 if( (!sArgs.NoCrossPredict_flag && sArgs.output_given && sArgs.labels_given && !sArgs.modelPrefix_given && !sArgs.model_given) || 01125 sArgs.CrossResult_given ){ 01126 01127 if( sArgs.CrossResult_given ){ 01128 if(!Results.Open(sArgs.CrossResult_arg, sArgs.mmap_flag)) { 01129 cerr << "Could not open: " << sArgs.CrossResult_arg << endl; 01130 return 1; 01131 } 01132 } 01133 01134 cerr << "Label pairs set to cross-validation values" << endl; 01135 01136 // map result gene list to dataset gene list 01137 veciGene.resize( vecResults[ 0 ]->GetGenes() ); 01138 for(i = 0; i < vecResults[ 0 ]->GetGenes(); i++){ 01139 veciGene[ i ] = Results.GetGene( vecResults[ 0 ]->GetGene( i ) ); 01140 } 01141 01142 // compute prediction component for this dataset/SVM model 01143 for(i = 0; i < vecResults[ 0 ]->GetGenes(); i++){ 01144 if( (iGene = veciGene[i]) == -1 ) 01145 continue; 01146 01147 for(j = i+1; j < vecResults[ 0 ]->GetGenes(); j++){ 01148 if( (jGene = veciGene[j]) == -1 ) 01149 continue; 01150 01151 if( CMeta::IsNaN(d = Results.Get(iGene, jGene) ) ) 01152 continue; 01153 01154 vecResults[ 0 ]->Set(i, j, d); 01155 } 01156 } 01157 } 01158 01159 // now save the averged prediction values 01160 vecResults[ 0 ]->Save(sArgs.output_arg); 01161 } 01162 01163 } 01164