Sleipnir
tools/SVMperfing/SVMperfing.cpp
00001 #include <fstream>
00002 
00003 #include <vector>
00004 #include <queue>
00005 
00006 /*****************************************************************************
00007  * This file is provided under the Creative Commons Attribution 3.0 license.
00008  *
00009  * You are free to share, copy, distribute, transmit, or adapt this work
00010  * PROVIDED THAT you attribute the work to the authors listed below.
00011  * For more information, please see the following web page:
00012  * http://creativecommons.org/licenses/by/3.0/
00013  *
00014  * This file is a component of the Sleipnir library for functional genomics,
00015  * authored by:
00016  * Curtis Huttenhower (chuttenh@princeton.edu)
00017  * Mark Schroeder
00018  * Maria D. Chikina
00019  * Olga G. Troyanskaya (ogt@princeton.edu, primary contact)
00020  *
00021  * If you use this library, the included executable tools, or any related
00022  * code in your work, please cite the following publication:
00023  * Curtis Huttenhower, Mark Schroeder, Maria D. Chikina, and
00024  * Olga G. Troyanskaya.
00025  * "The Sleipnir library for computational functional genomics"
00026  *****************************************************************************/
00027 #include "stdafx.h"
00028 #include "cmdline.h"
00029 #include "statistics.h"
00030 
00031 using namespace SVMLight;
00032 
00033 struct ParamStruct {
00034     vector<float> vecK, vecTradeoff;
00035     vector<size_t> vecLoss;
00036     vector<char*> vecNames;
00037 };
00038 
00039 ParamStruct ReadParamsFromFile(ifstream& ifsm, string outFile) {
00040     static const size_t c_iBuffer = 1024;
00041     char acBuffer[c_iBuffer];
00042     char* nameBuffer;
00043     vector<string> vecstrTokens;
00044     size_t extPlace;
00045     string Ext, FileName;
00046     if ((extPlace = outFile.find_first_of(".")) != string::npos) {
00047         FileName = outFile.substr(0, extPlace);
00048         Ext = outFile.substr(extPlace, outFile.size());
00049     } else {
00050         FileName = outFile;
00051         Ext = "";
00052     }
00053     ParamStruct PStruct;
00054     size_t index = 0;
00055     while (!ifsm.eof()) {
00056         ifsm.getline(acBuffer, c_iBuffer - 1);
00057         acBuffer[c_iBuffer - 1] = 0;
00058         vecstrTokens.clear();
00059         CMeta::Tokenize(acBuffer, vecstrTokens);
00060         if (vecstrTokens.empty())
00061             continue;
00062         if (acBuffer[0] != '#' && vecstrTokens.size() != 3) {
00063             cerr << "Illegal params line (" << vecstrTokens.size() << "): "
00064                     << acBuffer << endl;
00065             continue;
00066         }
00067         if (acBuffer[0] == '#') {
00068             cerr << "skipping " << acBuffer << endl;
00069         } else {
00070             PStruct.vecLoss.push_back(atoi(vecstrTokens[0].c_str()));
00071             PStruct.vecTradeoff.push_back(atof(vecstrTokens[1].c_str()));
00072             PStruct.vecK.push_back(atof(vecstrTokens[2].c_str()));
00073             PStruct.vecNames.push_back(new char[c_iBuffer]);
00074             if (PStruct.vecLoss[index] == 4 || PStruct.vecLoss[index] == 5)
00075                 sprintf(PStruct.vecNames[index], "%s_l%d_c%4.6f_k%4.3f%s",
00076                         FileName.c_str(), (int)PStruct.vecLoss[index],
00077                         PStruct.vecTradeoff[index], PStruct.vecK[index],
00078                         Ext.c_str());
00079             else
00080                 sprintf(PStruct.vecNames[index], "%s_l%d_c%4.6f%s",
00081                         FileName.c_str(), (int)PStruct.vecLoss[index],
00082                         PStruct.vecTradeoff[index], Ext.c_str());
00083             index++;
00084         }
00085 
00086     }
00087     return PStruct;
00088 }
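
// Expected parameter-file format (inferred from the parsing above, not from separate
// documentation): one whitespace-delimited, non-comment line per SVM run with three
// columns -- loss function index, tradeoff (C), and k value -- for example:
//
//   # loss  tradeoff  k
//   2       0.1       0.05
//   4       1.0       0.10
//
// Lines beginning with '#' are skipped; each remaining line yields an output file name
// of the form <FileName>_l<loss>_c<tradeoff>[_k<k>]<Ext>, where the _k<k> part is added
// only for loss functions 4 and 5.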
00089 
00090 bool ReadModelFile(ifstream& ifsm, vector<float>& SModel) {
00091     static const size_t c_iBuffer = 1024;
00092     char acBuffer[c_iBuffer];
00093     char* nameBuffer;
00094     vector<string> vecstrTokens;
00095     size_t extPlace;
00096     string Ext, FileName;
00097     size_t index = 0;
00098     
00099     while (!ifsm.eof()) {
00100         ifsm.getline(acBuffer, c_iBuffer - 1);
00101         acBuffer[c_iBuffer - 1] = 0;
00102         vecstrTokens.clear();
00103         CMeta::Tokenize(acBuffer, vecstrTokens);
00104         if (vecstrTokens.empty())
00105             continue;
00106         if (acBuffer[0] != '#' && vecstrTokens.size() > 1) {
00107             cerr << "Illegal model line (" << vecstrTokens.size() << "): "
00108                     << acBuffer << endl;
00109             continue;
00110         }
00111         if (acBuffer[0] == '#') {
00112             cerr << "skipping " << acBuffer << endl;
00113         } else {
00114           SModel.push_back(atof(vecstrTokens[0].c_str()));
00115         }
00116 
00117     }
00118     ifsm.close();
00119     return true;
00120 }
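
// Model-file format assumed by ReadModelFile and by the all-genes prediction loop in
// main() below: plain text with one linear SVM weight per non-comment line, one weight
// per input dataset, in the same alphabetically sorted order as the datasets read from
// the input dataset directory; lines beginning with '#' are treated as comments.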
00121 
00122 // Read in the Platt sigmoid parameters A and B from a probability parameter file
00123 bool ReadProbParamFile(char* prob_file, float& A, float& B) {
00124     static const size_t c_iBuffer = 1024;
00125     char acBuffer[c_iBuffer];
00126     char* nameBuffer;
00127     vector<string> vecstrTokens;
00128     size_t i, extPlace;
00129     string Ext, FileName;
00130     size_t index = 0;
00131     ifstream ifsm;
00132     
00133     ifsm.open( prob_file );
00134     i = 0;
00135     while (!ifsm.eof()) {
00136         ifsm.getline(acBuffer, c_iBuffer - 1);
00137         acBuffer[c_iBuffer - 1] = 0;
00138         vecstrTokens.clear();
00139         CMeta::Tokenize(acBuffer, vecstrTokens);
00140         if (vecstrTokens.empty())
00141             continue;
00142         if (acBuffer[0] != '#' && vecstrTokens.size() > 1) {
00143             cerr << "Illegal prob parameter line (" << vecstrTokens.size() << "): "
00144                     << acBuffer << endl;
00145             continue;
00146         }
00147         if (acBuffer[0] == '#') {
00148           cerr << "skipping " << acBuffer << endl;
00149         } else {
00150           if( i == 0 )        
00151             A = atof(vecstrTokens[0].c_str());
00152           else if( i == 1 )
00153             B = atof(vecstrTokens[0].c_str());
00154           else{
00155             cerr << "Error: prob parameter file has more than two value lines: " << acBuffer << endl;
00156             return false;
00157           }       
00158           i++;
00159         }       
00160     }
00161     cerr << "Reading Prob file, A: " << A << ", B: " << B << endl;
00162     return true;
00163 }
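
// Probability-parameter file format: two non-comment lines holding the Platt sigmoid
// parameters, A on the first line and B on the second -- the same layout written out
// below when a trained model is saved (ofsm << A << endl; ofsm << B << endl;).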
00164 
00165 // Platt's binary SVM probabilistic output
00166 // Assumes dec_values and labels have the same dimensions and gene order
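// Fits A and B so that P(label = +1 | decision value d) is modeled as
//   1 / (1 + exp(A*d + B)).
// Following Platt (with the regularized targets also used by LIBSVM),
//   t+ = (prior1 + 1) / (prior1 + 2)   for positive pairs,
//   t- = 1 / (prior0 + 2)              for negative pairs,
// and (A, B) minimize the negative log-likelihood by Newton's method, with a small
// diagonal shift (sigma) to keep the Hessian positive definite and a backtracking
// line search on the step size.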
00167 static void sigmoid_train(CDat& dec_values, 
00168               CDat& labels, 
00169               float& A, float& B){
00170     double prior1=0, prior0 = 0;
00171     size_t i, j, idx, k;
00172     float d, lab;
00173     
00174     int max_iter=100;   // Maximal number of iterations
00175     double min_step=1e-10;  // Minimal step taken in line search
00176     double sigma=1e-12; // For numerically strict PD of Hessian
00177     double eps=1e-5;
00178     vector<double> t;
00179     double fApB,p,q,h11,h22,h21,g1,g2,det,dA,dB,gd,stepsize;
00180     double newA,newB,newf,d1,d2;
00181     int iter; 
00182     
00183     // Negatives are values less than 0
00184     for(i = 0; i < dec_values.GetGenes(); i++)
00185       for(j = (i+1); j < dec_values.GetGenes(); j++)
00186         if (!CMeta::IsNaN(d = dec_values.Get(i, j)) && !CMeta::IsNaN(lab = labels.Get(i, j))  ){
00187           if(lab > 0)
00188             prior1 += 1;
00189           else if(lab < 0)
00190             prior0 += 1;
00191         }
00192     
00193     // initialize size
00194     t.resize(prior0+prior1);
00195     
00196     // Initial Point and Initial Fun Value
00197     A=0.0; B=log((prior0+1.0)/(prior1+1.0));
00198     double hiTarget=(prior1+1.0)/(prior1+2.0);
00199     double loTarget=1/(prior0+2.0);         
00200     double fval = 0.0;
00201         
00202     for(idx = i = 0; idx < dec_values.GetGenes(); idx++)
00203       for(j = (idx+1); j < dec_values.GetGenes(); j++)
00204         if (!CMeta::IsNaN(d = dec_values.Get(idx, j)) && !CMeta::IsNaN(lab = labels.Get(idx, j))  ){
00205           if (lab > 0 ) t[i]=hiTarget;
00206           else t[i]=loTarget;
00207                   
00208           fApB = d*A+B;
00209           if (fApB>=0)
00210             fval += t[i]*fApB + log(1+exp(-fApB));
00211           else
00212             fval += (t[i] - 1)*fApB + log(1+exp(fApB));
00213           ++i;
00214         }
00215     
00216     for (iter=0;iter<max_iter;iter++){
00217       // Update Gradient and Hessian (use H' = H + sigma I)
00218       h11=sigma; // numerically ensures strict PD
00219       h22=sigma;
00220       h21=0.0;g1=0.0;g2=0.0;
00221       
00222       for(i = idx = 0; idx < dec_values.GetGenes(); idx++)
00223         for(j = (idx+1); j < dec_values.GetGenes(); j++)
00224           if (!CMeta::IsNaN(d = dec_values.Get(idx, j)) && !CMeta::IsNaN(lab = labels.Get(idx, j))  ){                      
00225         fApB = d*A+B;       
00226         
00227         if (fApB >= 0){
00228           p=exp(-fApB)/(1.0+exp(-fApB));
00229           q=1.0/(1.0+exp(-fApB));
00230         }
00231         else{
00232           p=1.0/(1.0+exp(fApB));
00233           q=exp(fApB)/(1.0+exp(fApB));
00234         }
00235         d2=p*q;
00236         h11+=d*d*d2;
00237         h22+=d2;
00238         h21+=d*d2;
00239         d1=t[i]-p;
00240         g1+=d*d1;
00241         g2+=d1;
00242         
00243         ++i;
00244           }
00245       
00246       // Stopping Criteria
00247       if (fabs(g1)<eps && fabs(g2)<eps)
00248         break;
00249       
00250       // Finding Newton direction: -inv(H') * g
00251       det=h11*h22-h21*h21;
00252       dA=-(h22*g1 - h21 * g2) / det;
00253       dB=-(-h21*g1+ h11 * g2) / det;
00254       gd=g1*dA+g2*dB;
00255       
00256       stepsize = 1;     // Line Search
00257       while (stepsize >= min_step){
00258         newA = A + stepsize * dA;
00259         newB = B + stepsize * dB;
00260         
00261         // New function value
00262         newf = 0.0;
00263         
00264         for(i = idx = 0; idx < dec_values.GetGenes(); idx++)
00265           for(j = (idx+1); j < dec_values.GetGenes(); j++)
00266         if (!CMeta::IsNaN(d = dec_values.Get(idx, j)) && !CMeta::IsNaN(lab = labels.Get(idx, j))  ){                        
00267           fApB = d*newA+newB;
00268           
00269           if (fApB >= 0)
00270             newf += t[i]*fApB + log(1+exp(-fApB));
00271           else
00272             newf += (t[i] - 1)*fApB +log(1+exp(fApB));
00273           
00274           ++i;
00275         }
00276         
00277         // Check sufficient decrease
00278         if (newf<fval+0.0001*stepsize*gd){
00279           A=newA;B=newB;fval=newf;
00280           break;
00281         }
00282         else
00283           stepsize = stepsize / 2.0;
00284       }
00285       
00286       if (stepsize < min_step){
00287         cerr << "Line search fails in two-class probability estimates: " << stepsize << ',' << min_step << endl;
00288         break;
00289       }
00290     }
00291     
00292     if (iter>=max_iter)
00293       cerr << "Reaching maximal iterations in two-class probability estimates" << endl; 
00294 }
00295 
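// Applies the fitted sigmoid to every non-NaN decision value in dec_values, replacing
// it in place with P(label = +1 | d) = 1 / (1 + exp(A*d + B)). The two branches below
// are algebraically identical; the split only avoids overflowing exp() when |A*d + B|
// is large.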
00296 static void sigmoid_predict(CDat& dec_values, float A, float B){
00297   size_t i, j;
00298   float d, fApB;
00299   
00300   for(i = 0; i < dec_values.GetGenes(); i++)
00301     for(j = (i+1); j < dec_values.GetGenes(); j++)
00302       if (!CMeta::IsNaN(d = dec_values.Get(i, j))){                         
00303     fApB = d*A+B;
00304     // 1-p used later; avoid catastrophic cancellation
00305     if (fApB >= 0)
00306       dec_values.Set(i,j, exp(-fApB)/(1.0+exp(-fApB)));
00307     else
00308       dec_values.Set(i,j, 1.0/(1+exp(fApB)));
00309       }
00310 }
00311 
00312 
00313 int main(int iArgs, char** aszArgs) {
00314     gengetopt_args_info sArgs;  
00315     SVMLight::CSVMPERF SVM;
00316     
00317     size_t i, j, iGene, jGene, numpos, numneg, iSVM;
00318     ifstream ifsm;
00319     float d, sample_rate, dval;
00320     int   iRet;
00321     map<string, size_t> mapstriZeros, mapstriDatasets;
00322     vector<string> vecstrDatasets;
00323     vector<bool> mapTgene;
00324     vector<bool> mapCgene;
00325     vector<size_t> mapTgene2fold;
00326     vector<int> tgeneCount;
00327     
00328     DIR* dp;
00329     struct dirent* ep;  
00330     CGenome Genome;
00331         CGenes Genes(Genome);
00332     
00333     CGenome GenomeTwo;
00334         CGenes Context(GenomeTwo);
00335     
00336     CGenome GenomeThree;
00337         CGenes Allgenes(GenomeThree);
00338     
00339     if (cmdline_parser(iArgs, aszArgs, &sArgs)) {
00340         cmdline_parser_print_help();
00341         return 1;
00342     }
00343     
00344     CMeta Meta( sArgs.verbosity_arg, sArgs.random_arg );
00345     
00346     SVM.SetVerbosity(sArgs.verbosity_arg);
00347     SVM.SetLossFunction(sArgs.error_function_arg);
00348     if (sArgs.k_value_arg > 1) {
00349         cerr << "k_value is > 1; setting to default 0.5" << endl;
00350         SVM.SetPrecisionFraction(0.5);
00351     } else if (sArgs.k_value_arg <= 0) {
00352         cerr << "k_value is <= 0; setting to default 0.5" << endl;
00353         SVM.SetPrecisionFraction(0.5);
00354     } else {
00355         SVM.SetPrecisionFraction(sArgs.k_value_arg);
00356     }
00357 
00358     
00359     if (sArgs.cross_validation_arg < 1){
00360       cerr << "cross_validation is < 1; it must be at least 1" << endl;
00361       return 1;
00362     }
00363     else if(sArgs.cross_validation_arg < 2){
00364       cerr << "cross_validation is set to 1; no cross-validation holdouts will be run." << endl;
00365     }
00366     
00367     SVM.SetTradeoff(sArgs.tradeoff_arg);
00368     if (sArgs.slack_flag)
00369         SVM.UseSlackRescaling();
00370     else
00371         SVM.UseMarginRescaling();
00372     
00373     
00374     if (!SVM.parms_check()) {
00375         cerr << "Sanity check failed, see above errors" << endl;
00376         return 1;
00377     }
00378     
00379     // read in the list of datasets
00380     if(sArgs.directory_arg ) {
00381       dp = opendir (sArgs.directory_arg);
00382       if (dp != NULL){
00383         while ((ep = readdir(dp)) != NULL){
00384           // skip ".", "..", and temp files ending in '~'
00385           if (ep->d_name[0] == '.' || ep->d_name[strlen(ep->d_name)-1] == '~') 
00386             continue;
00387           
00388           // currently opens all files; add a filter here to restrict file extensions
00389           vecstrDatasets.push_back((string)sArgs.directory_arg + "/" + ep->d_name);
00390         }
00391         (void) closedir (dp);       
00392                 
00393         cerr << "Input directory contains # datasets: " << vecstrDatasets.size() << '\n';
00394         // sort datasets in alphabetical order
00395         std::sort(vecstrDatasets.begin(), vecstrDatasets.end());
00396       }
00397       else{
00398         cerr << "Couldn't open the directory: " << sArgs.directory_arg << '\n';
00399         return 1;
00400       }   
00401     }
00402     
00403     // read target gene list
00404     if(sArgs.tgene_given ) {
00405       ifstream ifsm;
00406       ifsm.open(sArgs.tgene_arg);
00407       
00408       if (!Genes.Open(ifsm)) {
00409         cerr << "Could not open: " << sArgs.tgene_arg << endl;
00410         return 1;
00411       }
00412       ifsm.close();
00413     }
00414     
00415     // read context gene list
00416     if(sArgs.context_given ) {
00417       ifstream ifsm;
00418       ifsm.open(sArgs.context_arg);
00419       
00420       if (!Context.Open(ifsm)) {
00421         cerr << "Could not open: " << sArgs.context_arg << endl;
00422         return 1;
00423       }
00424       ifsm.close();
00425     }
00426         
00427     // read all gene list
00428     // IF given this flag predict for all gene pairs
00429     if(sArgs.allgenes_given ) {
00430       ifstream ifsm;
00431       ifsm.open(sArgs.allgenes_arg);
00432       
00433       if (!Allgenes.Open(ifsm)) {
00434         cerr << "Could not open: " << sArgs.allgenes_arg << endl;
00435         return 1;
00436       }
00437       ifsm.close();
00438     }
00439     
00441     // Chris added
00442     vector<SVMLight::SVMLabelPair*> vecLabels;
00443     CDat Labels;
00444     CDat Results;
00445     
00446 
00450     if ( sArgs.labels_given ) {   
00451       if (!Labels.Open(sArgs.labels_arg, sArgs.mmap_flag)) {
00452         cerr << "Could not open input labels Dat" << endl;
00453         return 1;
00454       }
00455       
00456       // random sample labels
00457       if( sArgs.subsample_given ){
00458         cerr << "Sub-sample labels to rate:" << sArgs.subsample_arg << endl;
00459         for( i = 0; i < Labels.GetGenes( ); ++i )
00460           for( j = ( i + 1 ); j < Labels.GetGenes( ); ++j )
00461         if( !CMeta::IsNaN( Labels.Get( i, j ) ) &&
00462             ( ( (float)rand( ) / RAND_MAX ) > sArgs.subsample_arg ) )
00463           Labels.Set( i, j, CMeta::GetNaN( ) );             
00464       }   
00465       
00466       // set all NaN values to negatives
00467       if( sArgs.nan2neg_given ){
00468         cerr << "Set NaN labels dat as negatives" << endl;
00469         
00470         for(i = 0; i < Labels.GetGenes(); i++)
00471           for(j = (i+1); j < Labels.GetGenes(); j++)
00472         if (CMeta::IsNaN(d = Labels.Get(i, j)))  
00473           Labels.Set(i, j, -1);
00474       }
00475       
00476       if( sArgs.tgene_given ){
00477         mapTgene.resize(Labels.GetGenes());
00478         
00479         for(i = 0; i < Labels.GetGenes(); i++){
00480           if(Genes.GetGene(Labels.GetGene(i)) == -1)
00481         mapTgene[i] = false;
00482           else
00483         mapTgene[i] = true;
00484         }
00485         
00486         // keep track of positive gene counts
00487         tgeneCount.resize(Labels.GetGenes());
00488                 
00489         // if given a target gene file
00490         // Only keep edges that have exactly one gene in this target gene list
00491         if( sArgs.onetgene_flag ){
00492           cerr << "Filtering to only include edges with exactly one gene in gene file: " << sArgs.tgene_arg << endl;
00493           
00494           for(i = 0; i < Labels.GetGenes(); i++)
00495         for(j = (i+1); j < Labels.GetGenes(); j++)
00496           if (!CMeta::IsNaN(d = Labels.Get(i, j))){
00497             if(mapTgene[i] && mapTgene[j])
00498               Labels.Set(i, j, CMeta::GetNaN());
00499             else if(!mapTgene[i] && !mapTgene[j])
00500               Labels.Set(i, j, CMeta::GetNaN());
00501           }
00502         }
00503       } // if the edgeholdout flag is not given, we do gene holdout by default.
00504       // Since a target gene list was not given, use all genes in the labels as target genes for cross-validation holdout
00505       else if( !sArgs.edgeholdout_flag ){
00506         mapTgene.resize(Labels.GetGenes());
00507         
00508         // all genes are target genes
00509         for(i = 0; i < Labels.GetGenes(); i++){
00510           mapTgene[i] = true;
00511         }
00512         
00513         // keep track of positive gene counts
00514         tgeneCount.resize(Labels.GetGenes());
00515       }
00516       
00517       //if given a context map the context genes
00518       if( sArgs.context_given ){
00519         mapCgene.resize(Labels.GetGenes());
00520         
00521         for(i = 0; i < Labels.GetGenes(); i++){
00522           if(Context.GetGene(Labels.GetGene(i)) == -1)
00523         mapCgene[i] = false;
00524           else
00525         mapCgene[i] = true;
00526         }
00527       }
00528       
00529       // Set target prior
00530       if(sArgs.prior_given){
00531         numpos = 0;
00532         numneg = 0;
00533         for(i = 0; i < Labels.GetGenes(); i++)
00534           for(j = (i+1); j < Labels.GetGenes(); j++)
00535         if (!CMeta::IsNaN(d = Labels.Get(i, j))){
00536           if(d > 0){
00537             ++numpos;}
00538           else if(d < 0){
00539             ++numneg;
00540           }
00541         }
00542         
00543         if( ((float)numpos / (numpos + numneg)) < sArgs.prior_arg){
00544           
00545           cerr << "Convert prior from orig: " << ((float)numpos / (numpos + numneg)) << " to target: " << sArgs.prior_arg << endl;
00546           
00547           sample_rate = ((float)numpos / (numpos + numneg)) / sArgs.prior_arg;
00548           
00549           // remove neg labels to reach prior
00550           for(i = 0; i < Labels.GetGenes(); i++)
00551         for(j = (i+1); j < Labels.GetGenes(); j++)
00552           if (!CMeta::IsNaN(d = Labels.Get(i, j)) && d < 0){
00553             if((float)rand() / RAND_MAX  > sample_rate)
00554               Labels.Set(i, j, CMeta::GetNaN());
00555           }
00556         }
00557       }
00558       
00559       // output sample labels for eval/debug purpose
00560       if(sArgs.OutLabels_given){
00561         cerr << "save sampled labels as (1,0)s to: " << sArgs.OutLabels_arg << endl;
00562         Labels.Normalize( CDat::ENormalizeMinMax );
00563         Labels.Save(sArgs.OutLabels_arg);
00564         return 0;
00565       }
00566       
00567       // Exclude labels without context genes
00568       if(sArgs.context_given )
00569         Labels.FilterGenes( Context, CDat::EFilterInclude );
00570       
00571       
00572       // If not given an SVM model (or models) we are in learning mode, so construct an SVMLabelPair object for each label
00573       if( !sArgs.model_given && !sArgs.modelPrefix_given ){
00574         numpos = 0;
00575         for(i = 0; i < Labels.GetGenes(); i++)
00576           for(j = (i+1); j < Labels.GetGenes(); j++)
00577         if (!CMeta::IsNaN(d = Labels.Get(i, j))){
00578           if (d != 0)  
00579             vecLabels.push_back(new SVMLight::SVMLabelPair(d, i, j));
00580           if(d > 0)
00581             ++numpos;
00582         }
00583         
00584         // check to see if you have enough positives to learn from
00585         if(sArgs.mintrain_given && sArgs.mintrain_arg > numpos){
00586           cerr << "Not enough positive examples from: " << sArgs.labels_arg << " numpos: " << numpos << endl;
00587           return 1;
00588         }
00589       }
00590       
00591       // save sampled labels 
00592       if(sArgs.SampledLabels_given) {
00593         cerr << "Save sampled labels file: " << sArgs.SampledLabels_arg << endl;
00594         Labels.Save(sArgs.SampledLabels_arg);       
00595       }   
00596     } 
00597     
00598     SVMLight::SAMPLE* pTrainSample;
00599     SVMLight::SAMPLE* pAllSample;
00600     vector<SVMLight::SVMLabelPair*> pTrainVector[sArgs.cross_validation_arg];
00601     vector<SVMLight::SVMLabelPair*> pTestVector[sArgs.cross_validation_arg];
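    // Note: runtime-sized arrays of class type, as declared above, are a compiler
    // extension rather than standard C++; a portable alternative would be
    // std::vector< vector<SVMLight::SVMLabelPair*> > sized to cross_validation_arg.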
00602     
00603     // ###################################
00604     //
00605     // Now conduct learning if given labels and no SVM models
00606     //
00607     if (sArgs.output_given && sArgs.labels_given && !sArgs.modelPrefix_given && !sArgs.model_given) {
00608       //do learning and classifying with cross validation
00609       if( sArgs.cross_validation_arg > 1){
00610         cerr << "setting up cross-validation holdouts" << endl;
00611         
00612         mapTgene2fold.resize(mapTgene.size());
00613         
00614         // assign target genes to their cross-validation fold
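        // Gene-holdout scheme: each target gene is randomly assigned to one of the
        // cross_validation_arg folds; a labeled edge goes into the test set of the fold
        // its target gene was assigned to and into the training sets of the other folds
        // (subject to the balance/context filters below), so a fold never trains on
        // edges of the target genes it is evaluated on.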
00615         if(sArgs.tgene_given || !sArgs.edgeholdout_flag){
00616           for(i = 0; i < mapTgene.size(); i++){
00617         if(!mapTgene[i]){
00618           mapTgene2fold[i] = -1; 
00619           continue;
00620         }
00621         //cerr << "what's up?" << endl;
00622         mapTgene2fold[i] = rand() % sArgs.cross_validation_arg;
00623           }
00624           
00625           // cross-fold by target gene
00626           for (i = 0; i < sArgs.cross_validation_arg; i++) {
00627         cerr << "cross-validation holdout setup: " << i << endl;
00628         
00629         // keep track of positive gene counts
00630         if(sArgs.balance_flag){
00631           cerr << "Set up balance: " << i << endl;
00632           for(j = 0; j < Labels.GetGenes(); j++)
00633             tgeneCount[j] = 0;
00634           
00635           for(j = 0; j < vecLabels.size(); j++)
00636             if(vecLabels[j]->Target > 0){
00637               ++(tgeneCount[vecLabels[j]->iidx]);
00638               ++(tgeneCount[vecLabels[j]->jidx]);
00639             }
00640           
00641           if(sArgs.bfactor_given)
00642             for(j = 0; j < vecLabels.size(); j++)
00643               if(tgeneCount[vecLabels[j]->jidx] < 500)
00644             tgeneCount[vecLabels[j]->jidx] = sArgs.bfactor_arg*tgeneCount[vecLabels[j]->jidx];
00645         }
00646         
00647         for (j = 0; j < vecLabels.size(); j++) {
00648           //if( j % 1000 == 0)
00649           //cerr << "cross validation push labels:" << j << endl;
00650           
00651           // assume only one gene is a target gene in an edge
00652           if(mapTgene[vecLabels[j]->iidx]){
00653             if(vecLabels[j]->Target < 0){
00654               --(tgeneCount[vecLabels[j]->iidx]);
00655             }
00656             
00657             if(mapTgene2fold[vecLabels[j]->iidx] == i)              
00658               pTestVector[i].push_back(vecLabels[j]);
00659             else{
00660               //cerr << tgeneCount[vecLabels[j]->iidx] << endl;
00661               
00662               if( sArgs.balance_flag && vecLabels[j]->Target < 0 && tgeneCount[vecLabels[j]->iidx] < 0){
00663             continue;
00664               }
00665               
00666               // only add if both genes are in context
00667               if( sArgs.context_given  && ( !mapCgene[vecLabels[j]->iidx] || !mapCgene[vecLabels[j]->jidx]))
00668             continue;
00669               
00670               pTrainVector[i].push_back(vecLabels[j]); 
00671             }
00672           }
00673           else if(mapTgene[vecLabels[j]->jidx]){
00674             if(vecLabels[j]->Target < 0)
00675               --(tgeneCount[vecLabels[j]->jidx]);
00676             
00677             if(mapTgene2fold[vecLabels[j]->jidx] == i)
00678               pTestVector[i].push_back(vecLabels[j]);
00679             else{
00680               //cerr << tgeneCount[vecLabels[j]->jidx] << endl;
00681               
00682               if( sArgs.balance_flag && vecLabels[j]->Target < 0 && tgeneCount[vecLabels[j]->jidx] < 0){
00683             continue;
00684               }
00685               
00686               // only add if both genes are in context
00687               if( sArgs.context_given && ( !mapCgene[vecLabels[j]->iidx] || !mapCgene[vecLabels[j]->jidx]))
00688             continue;
00689               
00690               pTrainVector[i].push_back(vecLabels[j]); 
00691             }
00692           }
00693           else{
00694             cerr << "Error: edge exists without a target gene" << endl; return 1;
00695           }
00696         }
00697         
00698         cerr << "test,"<< i <<": " << pTestVector[i].size() << endl;
00699         int numpos = 0;
00700         for(j=0; j < pTrainVector[i].size(); j++)
00701           if(pTrainVector[i][j]->Target > 0)
00702             ++numpos;
00703         
00704         if( numpos < 1 || (sArgs.mintrain_given && sArgs.mintrain_arg > numpos) ){                      
00705           cerr << "Not enough positive examples from fold: " << i  << " file: " << sArgs.labels_arg << " numpos: " << numpos << endl;
00706           return 1;
00707         }
00708         
00709         cerr << "train,"<< i <<","<<numpos<<": " << pTrainVector[i].size() << endl;
00710         
00711           }
00712         }
00713         else{ // randomly assign edges to cross-validation folds
00714           if( sArgs.context_given ){
00715         cerr << "context not implemented yet for random edge holdout" << endl;
00716         return 1;
00717           }
00718           
00719           for (i = 0; i < sArgs.cross_validation_arg; i++) {
00720         pTestVector[i].reserve((size_t) vecLabels.size()
00721                        / sArgs.cross_validation_arg + sArgs.cross_validation_arg);
00722         pTrainVector[i].reserve((size_t) vecLabels.size()
00723                     / (sArgs.cross_validation_arg)
00724                     * (sArgs.cross_validation_arg - 1)
00725                     + sArgs.cross_validation_arg);
00726         for (j = 0; j < vecLabels.size(); j++) {
00727           if (j % sArgs.cross_validation_arg == i) {
00728             pTestVector[i].push_back(vecLabels[j]);
00729           } else {
00730             pTrainVector[i].push_back((vecLabels[j]));
00731           }
00732         }
00733           }
00734         }
00735       }
00736       else{ // with fewer than 2 cross-validation folds, no holdout is done; all labeled pairs are used for both training and testing
00737         
00738         // no holdout so train is the same as test gene set
00739         pTestVector[0].reserve((size_t) vecLabels.size() + sArgs.cross_validation_arg);
00740         pTrainVector[0].reserve((size_t) vecLabels.size() + sArgs.cross_validation_arg);
00741         
00742         for (j = 0; j < vecLabels.size(); j++) {
00743           pTestVector[0].push_back(vecLabels[j]);             
00744           pTrainVector[0].push_back(vecLabels[j]);          
00745         }
00746       }
00747       
00748       // initialize the results
00749       Results.Open(Labels.GetGeneNames(), true);
00750       
00752       // Create feature vectors for all Label pairs using input datasets
00753       //
00754       cerr << "CreateDocs!"<< endl;
00755       if(sArgs.normalizeZero_flag){
00756         SVMLight::CSVMPERF::CreateDoc(vecstrDatasets,
00757                       vecLabels,
00758                       Labels.GetGeneNames(),
00759                       Sleipnir::CDat::ENormalizeMinMax);
00760       }else if(sArgs.normalizeNPone_flag){
00761         SVMLight::CSVMPERF::CreateDoc(vecstrDatasets,
00762                       vecLabels,
00763                       Labels.GetGeneNames(),
00764                       Sleipnir::CDat::ENormalizeMinMaxNPone);
00765       }else{
00766         SVMLight::CSVMPERF::CreateDoc(vecstrDatasets,
00767                       vecLabels,
00768                       Labels.GetGeneNames());
00769       }
00770       
00772       // Start learning for each cross validation fold
00773       for (i = 0; i < sArgs.cross_validation_arg; i++) {
00774         std::stringstream sstm;
00775         
00776         // build up the output SVM model file name
00777         if(sArgs.context_given){
00778           std::string path(sArgs.context_arg);
00779           size_t pos = path.find_last_of("/");
00780           std::string cname;
00781           if(pos != std::string::npos)
00782         cname.assign(path.begin() + pos + 1, path.end());
00783           else
00784         cname = path;
00785           
00786           sstm << sArgs.output_arg << "." << sArgs.tradeoff_arg  << "." << cname << "." << i << ".svm";           
00787         }else
00788           sstm << sArgs.output_arg << "." << sArgs.tradeoff_arg  << "." << i << ".svm";
00789         
00790         cerr << "Cross validation fold: " << i << endl;             
00791         pTrainSample = SVMLight::CSVMPERF::CreateSample(pTrainVector[i]);
00792         
00793         // Skip learning if the SVM model file already exists
00794         if( sArgs.skipSVM_flag && 
00795         access((char*)(sstm.str().c_str()), R_OK ) != -1
00796         ){
00797           //SVM.ReadModel((char*)(sstm.str().c_str()));
00798           SVM.ReadModelLinear((char*)(sstm.str().c_str()));
00799           cerr << "Using existing trained SVM model: " << sstm.str() << endl;
00800         }else{
00801           SVM.Learn(*pTrainSample);
00802           cerr << "SVM model Learned" << endl;
00803         }
00804         
00805         SVM.Classify(Results,
00806              pTestVector[i]);
00807         
00808         cerr << "SVM model classified holdout" << endl;
00809         
00810         if( sArgs.savemodel_flag ){
00811           SVM.WriteModel((char*)(sstm.str().c_str()));
00812         }
00813         
00814         // DEBUG
00815         SVMLight::CSVMPERF::FreeSample_leave_Doc(*pTrainSample);
00816         free(pTrainSample);
00817       }
00818       
00819       if( sArgs.prob_flag ){
00820         cerr << "Converting prediction values to estimated probability" << endl;
00821         float A, B;
00822         
00823         // TODO add function to read in prob parameter file if already existing
00824         sigmoid_train(Results, Labels, A, B);
00825         sigmoid_predict(Results, A, B);
00826       }
00827       else if( sArgs.probCross_flag ){
00828         float A, B;
00829         size_t k, ctrain, itrain;
00830         vector<SVMLight::SVMLabelPair*> probTrainVector;
00831         
00832         for (i = 0; i < sArgs.cross_validation_arg; i++) {            
00833           cerr << "Convert to probability for cross fold: " << i << endl;             
00834           
00835           // construct prob file name
00836           std::stringstream pstm;       
00837           ofstream ofsm;              
00838           if(sArgs.context_given){
00839         std::string path(sArgs.context_arg);
00840         size_t pos = path.find_last_of("/");
00841         std::string cname;
00842         if(pos != std::string::npos)
00843           cname.assign(path.begin() + pos + 1, path.end());
00844         else
00845           cname = path;
00846         
00847         pstm << sArgs.output_arg << "." << sArgs.tradeoff_arg  << "." << cname  << "." << i << ".svm.prob";           
00848           }else
00849         pstm << sArgs.output_arg << "." << sArgs.tradeoff_arg  << "." << i << ".svm.prob";
00850           
00851           
00852           if( sArgs.skipSVM_flag && 
00853           access((char*)(pstm.str().c_str()), R_OK ) != -1
00854           ){
00855         
00856         // read in parameter file
00857         if(!ReadProbParamFile((char*)(pstm.str().c_str()), A, B)){
00858           cerr << "Failed to read Probability parameter file: " << pstm.str() << endl;
00859           return 1;
00860         }
00861         cerr << pstm.str() << ": read in values, A: " << A << ", B: " << B << endl;     
00862         
00863           }else{        
00864         ctrain = 0;
00865         for (j = 0; j < sArgs.cross_validation_arg; j++) {            
00866           if(i == j)
00867             continue;
00868           ctrain += pTrainVector[j].size();
00869         }
00870         
00871         probTrainVector.resize(ctrain);
00872         itrain = 0;
00873         for (j = 0; j < sArgs.cross_validation_arg; j++) {            
00874           if(i == j)
00875             continue;
00876           for(k = 0; k < pTrainVector[j].size(); k++){
00877             probTrainVector[itrain] = pTrainVector[j][k];
00878             itrain += 1;
00879           }
00880         }
00881         
00882         // train A,B sigmoid parameters
00883         SVM.sigmoid_train(Results, probTrainVector, A, B);      
00884           }
00885           
00886           SVM.sigmoid_predict(Results, pTestVector[i], A, B);
00887           
00888           // open prob param file
00889           if(sArgs.savemodel_flag){
00890         ofsm.open(pstm.str().c_str());
00891         ofsm << A << endl;
00892         ofsm << B << endl;
00893         ofsm.close();
00894           }
00895         }
00896       }
00897       
00898       // only save cross-validated results when not predicting all genes
00899       if( !sArgs.allgenes_given )
00900         Results.Save(sArgs.output_arg);
00901     }
00902     
00904     // If given the allgenes arg, this puts us in prediction mode for all gene pairs from the gene list
00905     //
00906     if ( sArgs.allgenes_given && sArgs.output_given) { //read model and classify all      
00907       size_t iData;
00908       vector<vector<float> > vecSModel;
00909       vecSModel.resize(sArgs.cross_validation_arg);
00910       ifstream mifsm;
00911       vector<CDat* > vecResults;
00912       vector<size_t> veciGene;
00913       
00914       cerr << "Predicting for all genes given." << endl;
00915       
00916       // open SVM models for prefix file
00917       if(sArgs.modelPrefix_given){
00918         for(i = 0; i < sArgs.cross_validation_arg; i++){
00919           std::stringstream sstm;
00920           
00921           sstm << sArgs.modelPrefix_arg << "." << i << ".svm";        
00922           if( access((char*)(sstm.str().c_str()), R_OK) == -1 ){
00923         cerr << "ERROR: SVM model file cannot be opened: " << sstm.str() << endl;
00924         return 1;
00925           }
00926           
00927           mifsm.open((char*)(sstm.str().c_str()));                    
00928           ReadModelFile(mifsm, vecSModel[i]);
00929         }
00930       }else if( sArgs.model_given ){ // open SVM model from file
00931         //vector<float> SModel;
00932         vecSModel.resize(1);
00933         
00934         if( access(sArgs.model_arg, R_OK) == -1 ){
00935           cerr << "ERROR: SVM model file cannot be opened: " << sArgs.model_arg << endl;
00936           return 1;
00937         }
00938         
00939         mifsm.open(sArgs.model_arg);
00940         ReadModelFile(mifsm, vecSModel[0]);
00941         
00942         // DEBUG check if this is ok
00943         sArgs.cross_validation_arg = 1;
00944       }else{ 
00945         // open the SVM model files that were just trained
00946         for(i = 0; i < sArgs.cross_validation_arg; i++){
00947           std::stringstream sstm;
00948           
00949           if(sArgs.context_given){
00950         std::string path(sArgs.context_arg);
00951         size_t pos = path.find_last_of("/");
00952         std::string cname;
00953         if(pos != std::string::npos)
00954           cname.assign(path.begin() + pos + 1, path.end());
00955         else
00956           cname = path;
00957         
00958         sstm << sArgs.output_arg << "." << sArgs.tradeoff_arg  << "." << cname << "." << i << ".svm";             
00959           }else
00960         sstm << sArgs.output_arg << "." << sArgs.tradeoff_arg  << "." << i << ".svm";
00961           
00962           
00963           if( access((char*)(sstm.str().c_str()), R_OK) == -1 ){
00964         cerr << "ERROR: SVM model file cannot be opened: " << sstm.str() << endl;
00965         return 1;
00966           }
00967           
00968           mifsm.open((char*)(sstm.str().c_str()));        
00969           ReadModelFile(mifsm, vecSModel[i]);
00970         }
00971       }
00972       
00973       // Initialize output for input all gene list
00974       vecResults.resize(sArgs.cross_validation_arg);
00975       for(i = 0; i < sArgs.cross_validation_arg; i++){
00976         vecResults[ i ] = new CDat();
00977         vecResults[ i ]->Open(Allgenes.GetGeneNames( ), true);
00978       }
00979       
00980           
00981       // Now iterate over all datasets to make predictions for all gene pairs
00982       CDat wDat;            
00983       for(iData = 0; iData < vecstrDatasets.size(); iData++){       
00984         if(!wDat.Open(vecstrDatasets[iData].c_str(), sArgs.mmap_flag)) {
00985           cerr << vecstrDatasets[iData].c_str() << endl;
00986           cerr << "Could not open: " << vecstrDatasets[iData] << endl;
00987           return 1;
00988         }
00989         
00990         cerr << "Open: " << vecstrDatasets[iData] << endl;
00991         
00992         // normalize data file
00993         if(sArgs.normalizeZero_flag){
00994           cerr << "Normalize input [0,1] data" << endl;
00995           wDat.Normalize( Sleipnir::CDat::ENormalizeMinMax );
00996         }else if(sArgs.normalizeNPone_flag){
00997           cerr << "Normalize input [-1,1] data" << endl;
00998           wDat.Normalize( Sleipnir::CDat::ENormalizeMinMaxNPone );
00999         }
01000         
01001         // map result gene list to dataset gene list
01002         veciGene.resize( vecResults[ 0 ]->GetGenes() );     
01003         for(i = 0; i < vecResults[ 0 ]->GetGenes(); i++){
01004           veciGene[ i ] = wDat.GetGene( vecResults[ 0 ]->GetGene( i ) );
01005         }
01006         
01007         // compute prediction component for this dataset/SVM model
01008         for(i = 0; i < vecResults[ 0 ]->GetGenes(); i++){
01009           iGene = veciGene[i];
01010           
01011           for(j = i+1; j < vecResults[ 0 ]->GetGenes(); j++){
01012         jGene = veciGene[j];
01013         
01014         // if no value, set feature to 0
01015         if( iGene == -1 || jGene == -1 || CMeta::IsNaN(d = wDat.Get(iGene, jGene) ) )
01016           d = 0;
01017         
01018         // iterate each SVM model
01019         for(iSVM = 0; iSVM < sArgs.cross_validation_arg; iSVM++){
01020           if( CMeta::IsNaN(dval = vecResults[ iSVM ]->Get(i, j)) )
01021             vecResults[ iSVM ]->Set(i, j, (d * vecSModel[iSVM][iData])  );
01022           else
01023             vecResults[ iSVM ]->Set(i, j, (dval + (d * vecSModel[iSVM][iData])) );
01024           
01025         }
01026         
01027           }
01028         }               
01029       }
01030       
01031       
01032       // convert the SVM predictions for each model to probability if required
01033       if( sArgs.prob_flag || sArgs.probCross_flag ){
01034         // iterate over each SVM model and its output and convert to probability
01035         float A;
01036         float B;
01037         
01038         cerr << "Converting prediction DABs to probability" << endl;
01039         
01040         for(iSVM = 0; iSVM < sArgs.cross_validation_arg; iSVM++){
01041           // Read A,B parameters
01042           std::stringstream pstm;
01043           if(sArgs.modelPrefix_given){
01044         pstm << sArgs.modelPrefix_arg  << "." << iSVM << ".svm.prob";
01045           }else if( sArgs.model_given ){ // open SVM model from file
01046         
01047         // DEBUG: this might need to be looked at
01048         pstm << sArgs.model_arg << ".prob";
01049         
01050           }else{ // open the SVM model files that were just trained
01051         if(sArgs.context_given){
01052           std::string path(sArgs.context_arg);
01053           size_t pos = path.find_last_of("/");
01054           std::string cname;
01055           if(pos != std::string::npos)
01056             cname.assign(path.begin() + pos + 1, path.end());
01057           else
01058             cname = path;
01059           
01060           pstm << sArgs.output_arg << "." << sArgs.tradeoff_arg  << "." << cname  << "." << iSVM << ".svm.prob";              
01061         }else
01062           pstm << sArgs.output_arg << "." << sArgs.tradeoff_arg  << "." << iSVM << ".svm.prob";
01063           }
01064           
01065           // read in parameter file
01066           if(!ReadProbParamFile((char*)(pstm.str().c_str()), A, B)){
01067         cerr << "Failed to read Probability parameter file: " << pstm.str() << endl;
01068         return 1;
01069           }
01070           
01071           cerr << pstm.str() << ", A: " << A << ",B: " << B << endl;
01072       
01073           // debug
01074           //cerr << "known:" << vecResults[ iSVM ]->Get(x,y) << endl;
01075           
01076           // now convert the SVM model prediction values to probability based on params A,B
01077           sigmoid_predict( *(vecResults[ iSVM ]), A, B);
01078           
01079           //cerr << "known PROB:" << vecResults[ iSVM ]->Get(x,y) << endl;
01080         }
01081       }
01082       
01083       // filter results
01084       if( sArgs.tgene_given ){
01085         for(iSVM = 0; iSVM < sArgs.cross_validation_arg; iSVM++){
01086           vecResults[ iSVM ]->FilterGenes( sArgs.tgene_arg, CDat::EFilterEdge );
01087         }
01088       }
01089       
01090       // Exclude pairs without context genes
01091       if(sArgs.context_given ){
01092         for(iSVM = 0; iSVM < sArgs.cross_validation_arg; iSVM++){
01093           vecResults[ iSVM ]->FilterGenes( Context, CDat::EFilterInclude );
01094         }
01095       }
01096       
01097       // Take the average of the prediction values from each model
01098       // The final results are stored into (overwrite) the first DAB (i.e. vecResults[ 0 ])
01099       for(i = 0; i < vecResults[ 0 ]->GetGenes(); i ++)
01100         for(j = i+1; j < vecResults[ 0 ]->GetGenes(); j++){
01101           if (CMeta::IsNaN(dval = vecResults[ 0 ]->Get(i, j)))
01102         continue;
01103           
01104           // start from the second
01105           // Assume all SVM model prediction dabs have identical NaN locations
01106           if(sArgs.aggregateMax_flag){ // take the maximum prediction value
01107         float maxval = dval;
01108         for(iSVM = 1; iSVM < sArgs.cross_validation_arg; iSVM++)
01109           if( vecResults[ iSVM ]->Get(i, j) > dval )
01110             dval = vecResults[ iSVM ]->Get(i, j);
01111         
01112         vecResults[ 0 ]->Set(i, j, dval);
01113           }else{ // Average over prediction values
01114         for(iSVM = 1; iSVM < sArgs.cross_validation_arg; iSVM++)
01115           dval += vecResults[ iSVM ]->Get(i, j);          
01116         
01117         vecResults[ 0 ]->Set(i, j, (dval / sArgs.cross_validation_arg) );
01118           }
01119         }
01120       
01121       // Replace gene-pair prediction values for labeled pairs with the cross-validation results:
01122       // each labeled pair gets the prediction made by the fold in which it was held out
01123       // Only do this if cross-validation was conducted or a replacement dab was given
01124       if( (!sArgs.NoCrossPredict_flag && sArgs.output_given && sArgs.labels_given && !sArgs.modelPrefix_given && !sArgs.model_given) || 
01125           sArgs.CrossResult_given ){
01126         
01127         if( sArgs.CrossResult_given ){        
01128           if(!Results.Open(sArgs.CrossResult_arg, sArgs.mmap_flag)) {
01129         cerr << "Could not open: " << sArgs.CrossResult_arg << endl;
01130         return 1;
01131           }
01132         }
01133         
01134         cerr << "Label pairs set to cross-validation values" << endl;
01135         
01136         // map result gene list to dataset gene list
01137         veciGene.resize( vecResults[ 0 ]->GetGenes() );     
01138         for(i = 0; i < vecResults[ 0 ]->GetGenes(); i++){
01139           veciGene[ i ] = Results.GetGene( vecResults[ 0 ]->GetGene( i ) );
01140         }
01141         
01142         // compute prediction component for this dataset/SVM model
01143         for(i = 0; i < vecResults[ 0 ]->GetGenes(); i++){
01144           if( (iGene = veciGene[i]) == -1 )
01145         continue;
01146           
01147           for(j = i+1; j < vecResults[ 0 ]->GetGenes(); j++){
01148         if( (jGene = veciGene[j]) == -1 )
01149           continue;
01150         
01151         if( CMeta::IsNaN(d = Results.Get(iGene, jGene) ) )
01152           continue;
01153         
01154         vecResults[ 0 ]->Set(i, j, d);
01155           }
01156         }
01157       }
01158       
01159       // now save the averaged prediction values
01160       vecResults[ 0 ]->Save(sArgs.output_arg);
01161     }
01162     return 0;
01163 }
01164