Sleipnir
src/svmperf.h
00001 /*****************************************************************************
00002  * This file is provided under the Creative Commons Attribution 3.0 license.
00003  *
00004  * You are free to share, copy, distribute, transmit, or adapt this work
00005  * PROVIDED THAT you attribute the work to the authors listed below.
00006  * For more information, please see the following web page:
00007  * http://creativecommons.org/licenses/by/3.0/
00008  *
00009  * This file is a component of the Sleipnir library for functional genomics,
00010  * authored by:
00011  * Curtis Huttenhower (chuttenh@princeton.edu)
00012  * Mark Schroeder
00013  * Maria D. Chikina
00014  * Olga G. Troyanskaya (ogt@princeton.edu, primary contact)
00015  *
00016  * If you use this library, the included executable tools, or any related
00017  * code in your work, please cite the following publication:
00018  * Curtis Huttenhower, Mark Schroeder, Maria D. Chikina, and
00019  * Olga G. Troyanskaya.
00020  * "The Sleipnir library for computational functional genomics"
00021  *****************************************************************************/
00022 
00023 #ifndef NO_SVM_PERF
00024 #ifndef SVMPERFI_H
00025 #define SVMPERFI_H
00026 #include "pclset.h"
00027 #include "meta.h"
00028 #include "dat.h"
00029 
00030 #include <stdio.h>
00031 
00032 /* removed to support cygwin */
00033 //#include <execinfo.h>
00034 
00035 namespace SVMLight {
00036 extern "C" {
00037 
00038 #define class Class
00039 
00040 #include <svm_light/svm_common.h>
00041 #include <svm_light/svm_learn.h>
00042 #include <svm_struct_api_types.h>
00043 #include <svm_struct/svm_struct_common.h>
00044 #include <svm_struct_api.h>
00045 #include <svm_struct/svm_struct_learn.h>
00046 #undef class
00047 //#include "svm_struct_api.h"
00048 
00049 }
00050 
/**
 * Label for a single gene: the gene's name, its target value and an optional
 * index (set via SetIndex()) of the gene in the data set it was matched to.
 */
class SVMLabel {
public:
    std::string GeneName; // gene name
    double Target;        // target value
    size_t index;         // index set by SetIndex(); (size_t)-1 while unset
    bool hasIndex;        // true once SetIndex() has been called

    // Label with the given name and target and no index yet.
    SVMLabel(const std::string& name, double target) :
        GeneName(name), Target(target), index(static_cast<size_t>(-1)),
        hasIndex(false) {
    }

    // Empty label: no name, target 0, no index. index/hasIndex are
    // initialized here as well; previously they were left indeterminate, so
    // testing hasIndex on a default-constructed label was undefined behavior.
    SVMLabel() :
        GeneName(""), Target(0), index(static_cast<size_t>(-1)),
        hasIndex(false) {
    }

    // Records the gene's index and marks the label as indexed.
    void SetIndex(size_t i) {
        index = i;
        hasIndex = true;
    }
};
00073 
00074 class SVMLabelPair {
00075 public:
00076     double Target;
00077     size_t iidx;
00078     size_t jidx;
00079     bool hasIndex;
00080     DOC* pDoc;
00081     
00082     SVMLabelPair(double target, size_t i, size_t j) {
00083         Target = target;
00084         hasIndex = true;
00085         iidx = i;
00086         jidx = j;
00087     }
00088 
00089     SVMLabelPair() {
00090         Target = 0;
00091         hasIndex = false;
00092     }
00093     
00094     void SetIndex(size_t i, size_t j) {
00095         iidx = i;
00096         jidx = j;
00097         hasIndex = true;
00098     }
00099     
00100     void SetDoc(DOC* inDoc) {
00101       pDoc = inDoc;
00102     }
00103     
00104     DOC* GetDoc() {
00105       return pDoc;
00106     }   
00107     
00108 };
00109 
00110 class Result {
00111 public:
00112     std::string GeneName;
00113     double Target;
00114     double Value;
00115     int CVround;
00116     int Rank;
00117     Result() {
00118         GeneName = "";
00119         Target = 0;
00120         Value = Sleipnir::CMeta::GetNaN();
00121     }
00122 
00123     Result(std::string name, int cv = -1) {
00124         GeneName = name;
00125         Target = 0;
00126         Value = 0;
00127         CVround = cv;
00128         Rank = -1;
00129     }
00130     string toString() {
00131         stringstream ss;
00132         ss << GeneName << '\t' << Target << '\t' << Value << '\t' << "CV"
00133                 << CVround;
00134         if (Rank != -1) {
00135             ss << '\t' << Rank;
00136         }
00137         return ss.str();
00138     }
00139 
00140 };
00141 
// Filter mode selector with two values: EFilterInclude (0) and
// EFilterExclude (1).
enum EFilter {
    EFilterInclude = 0,
    EFilterExclude = 1
};
00145 
00146 //this class encapsulates the model and parameters and has no associated data
00147 
00148 class CSVMPERF {
00149 public:
00150     LEARN_PARM learn_parm;
00151     KERNEL_PARM kernel_parm;
00152     STRUCT_LEARN_PARM struct_parm;
00153     STRUCTMODEL structmodel;
00154     int Alg;
00155     CSVMPERF(int a = 3) {
00156         Alg = a;
00157         initialize();
00158         //set_struct_verbosity(5);
00159     }
00160 
00161     void SetLossFunction(size_t loss_f) {
00162         struct_parm.loss_function = loss_f;
00163     }
00164 
00165     void SetTradeoff(double tradeoff) {
00166         struct_parm.C = tradeoff;
00167 
00168     }
00169     void SetKernel(int K) {
00170         kernel_parm.kernel_type = K;
00171     }
00172     void SetPolyD(int D) {
00173         kernel_parm.poly_degree = D;
00174     }
00175 
00176     void UseCPSP() {
00177         Alg = 9;
00178         struct_parm.preimage_method = 2;
00179         struct_parm.sparse_kernel_size = 500;
00180         struct_parm.bias = 0;
00181     }
00182 
00183     void SetRBFGamma(double g) {
00184         kernel_parm.rbf_gamma = g;
00185         UseCPSP();
00186     }
00187 
00188     void UseSlackRescaling() {
00189         struct_parm.loss_type = SLACK_RESCALING;
00190     }
00191 
00192     void UseMarginRescaling() {
00193         struct_parm.loss_type = MARGIN_RESCALING;
00194     }
00195 
00196     void SetPrecisionFraction(double frac) {
00197         struct_parm.prec_rec_k_frac = frac;
00198     }
00199 
00200     void ReadModel(char* model_file) {
00201         //FreeModel();
00202         structmodel = read_struct_model(model_file, &struct_parm);
00203     }
00204 
00205     void WriteModel(char* model_file, int simple_model_flag = 1) {
00206         if (kernel_parm.kernel_type == LINEAR && simple_model_flag) {
00207             ofstream ofsm;
00208             ofsm.open(model_file);
00209             for (size_t i = 0; i < structmodel.sizePsi; i++) {
00210                 ofsm << structmodel.w[i+1] << endl;
00211             }
00212         } else {
00213             write_struct_model(model_file, &structmodel, &struct_parm);
00214         }
00215     }
00216 
00217     void WriteWeights(ostream& osm) {
00218         osm << structmodel.w[0];
00219         for (size_t i = 1; i < structmodel.sizePsi + 1; i++)
00220             osm << '\t' << structmodel.w[i];
00221         osm << endl;
00222     }
00223 
00224     static void FreePattern(pattern x) {
00225         free_pattern(x);
00226     }
00227 
00228     static void FreeLabel(label y) {
00229         free_label(y);
00230     }
00231 
00232     void FreeModel() {
00233         free_struct_model(structmodel);
00234     }
00235 
00236     static void FreeSample(sample s) {
00237         free_struct_sample(s);
00238     }
00239 
00240     static void FreeDoc(DOC* pDoc) {
00241         free_example(pDoc, true);
00242     }
00243     void SetVerbosity(size_t V);
00244 
00245     //static members process data
00246     //single gene predictions
00247 
00248     //creates a Doc for a given gene index in a microarray set
00249     static DOC* CreateDoc(Sleipnir::CPCLSet &PCLSet, size_t iGene, size_t iDoc);
00250 
00251     //creates a Doc for a given gene index in a single microarray
00252     static DOC* CreateDoc(Sleipnir::CPCL &PCL, size_t iGene, size_t iDoc);
00253 
00254     //creates a Doc for a given gene in a Dat file using all other genes as features
00255     static DOC* CreateDoc(Sleipnir::CDat& Dat, size_t iGene, size_t iDoc);
00256 
00257     //Creates a sample using a PCLset and SVMlabels Looks up genes by name.
00258     static SAMPLE* CreateSample(Sleipnir::CPCLSet &PCLSet,
00259             vector<SVMLabel> SVMLabels);
00260 
00261     //Creates a sample using a single PCL and SVMlabels Looks up genes by name.
00262     static SAMPLE
00263     * CreateSample(Sleipnir::CPCL &PCL, vector<SVMLabel> SVMLabels);
00264 
00265     //Same as above except creates bootstrap samples and does not duplicate data
00266     static SAMPLE** CreateSampleBootStrap(Sleipnir::CPCL &PCL,
00267             vector<SVMLabel>& SVMLabels, vector<vector<size_t> > vecvecIndex);
00268 
00269     //Creates a sample using a Dat and SVMlabels. Looks up genes by name
00270     static SAMPLE* CreateSample(Sleipnir::CDat& CDat,
00271             vector<SVMLabel> SMVLabels);
00272 
00273     //Classify single genes
00274     vector<Result> Classify(Sleipnir::CPCL& PCL, vector<SVMLabel> SVMLabels);
00275     vector<Result> Classify(Sleipnir::CPCLSet& PCLSet,
00276             vector<SVMLabel> SVMLabels);
00277     vector<Result> Classify(Sleipnir::CDat& Dat, vector<SVMLabel> SVMLabels);
00278 
00279     //MEMBER functions wraps learning
00280     void Learn(SAMPLE &sample) {
00281         cerr << "SLACK NORM =" << struct_parm.slack_norm << endl;
00282         /*  if (kernel_parm.kernel_type==CUSTOM)
00283          svm_learn_struct_joint_custom(sample, &struct_parm, &learn_parm, &kernel_parm, &structmodel);
00284          else*/
00285         //Take care of the labels here
00286         size_t numn, nump;
00287         numn = nump = 0;
00288         size_t i;
00289         for (i = 0; i < sample.examples[0].x.totdoc; i++) {
00290             if (sample.examples[0].y.Class[i] > 0)
00291                 nump++;
00292             else
00293                 numn++;
00294         }
00295         //make scaling appropriate for the loss function being used
00296         if ((struct_parm.loss_function == ZEROONE)
00297                 || (struct_parm.loss_function == FONE)
00298                 || (struct_parm.loss_function == ERRORRATE)
00299                 || (struct_parm.loss_function == PRBEP)
00300                 || (struct_parm.loss_function == PREC_K)
00301                 || (struct_parm.loss_function == REC_K)) {
00302             for (i = 0; i < sample.examples[0].x.totdoc; i++) {
00303                 if (sample.examples[0].y.Class[i] > 0)
00304                     sample.examples[0].y.Class[i] = 0.5 * 100.0 / (numn + nump);
00305                 else
00306                     sample.examples[0].y.Class[i] = -0.5 * 100.0
00307                             / (numn + nump);
00308             }
00309         }
00310         /* change label value for easy computation of rankmetrics (i.e. ROC-area) */
00311         if (struct_parm.loss_function == SWAPPEDPAIRS) {
00312             for (i = 0; i < sample.examples[0].x.totdoc; i++) {
00313                 /*      if(sample.examples[0].y.class[i]>0)
00314                  sample.examples[0].y.class[i]=numn*0.5;
00315                  else
00316                  sample.examples[0].y.class[i]=-nump*0.5; */
00317                 if (sample.examples[0].y.Class[i] > 0)
00318                     sample.examples[0].y.Class[i] = 0.5 * 100.0 / nump;
00319                 else
00320                     sample.examples[0].y.Class[i] = -0.5 * 100.0 / numn;
00321                 /*              cerr << sample.examples[0].x.doc[i]->fvec->words[0].weight
00322                  << '\t' << sample.examples[0].y.Class[i] << endl;*/
00323             }
00324         }
00325         if (struct_parm.loss_function == AVGPREC) {
00326             for (i = 0; i < sample.examples[0].x.totdoc; i++) {
00327                 if (sample.examples[0].y.Class[i] > 0)
00328                     sample.examples[0].y.Class[i] = numn;
00329                 else
00330                     sample.examples[0].y.Class[i] = -nump;
00331             }
00332         }
00333         cerr << "ALG=" << Alg << endl;
00334         svm_learn_struct_joint(sample, &struct_parm, &learn_parm, &kernel_parm,
00335                        &structmodel, Alg);
00336         //
00337     }
00338 
00339     // Pair learning
00340 
00341     //creates a Doc for a pair of genes by feature-wise multiplication
00342     static DOC* CreateDoc(Sleipnir::CPCL &PCL, size_t iGene, size_t jGene,
00343             size_t iDoc);
00344 
00345     //creates the complete Sample given  Dat of answers and a list of CV Genes
00346     static SAMPLE* CreateSample(Sleipnir::CPCL &PCL, Sleipnir::CDat& Answers,
00347             const vector<string>& CVGenes);
00348 
00349     //Classify  gene pairs from the given CV List
00350     void Classify(Sleipnir::CPCL& PCL, Sleipnir::CDat& Answers,
00351             Sleipnir::CDat& Values, Sleipnir::CDat& Counts,
00352             const vector<string>& CVGenes);
00353 
00354     //Will classify all pairs EXCEPT ones that touch the CV list
00355     void ClassifyAll(Sleipnir::CPCL& PCL, Sleipnir::CDat& Values,
00356             Sleipnir::CDat& Counts, const vector<string>& CVGenes);
00357     //Same as above but won't keep track of the counts, saves on memory
00358     void ClassifyAll(Sleipnir::CPCL& PCL, Sleipnir::CDat& Values, const vector<
00359             string>& CVGenes);
00360     bool parms_check();
00361     bool initialize();
00362     
00363     
00364     //Pair & Multiple dabs learning
00365     static bool CreateDoc(vector<string>& vecstrDatasets,
00366                   vector<SVMLabelPair*>& vecLabels,
00367                   const vector<string>& LabelsGene,
00368                   Sleipnir::CDat::ENormalize eNormalize = Sleipnir::CDat::ENormalizeNone);
00369     
00370     static SAMPLE* CreateSample(vector<SVMLabelPair*>& SVMLabels);  
00371     void Classify(Sleipnir::CDat &Results,
00372               vector<SVMLabelPair*>& SVMLabels);
00373     
00374     // free the sample but don't free the Docs
00375     static void FreeSample_leave_Doc(SAMPLE s);
00376 
00377     // functions to convert probablity
00378     void sigmoid_train(Sleipnir::CDat& Results, vector<SVMLabelPair*>& SVMLabels, float& A, float& B);
00379     void sigmoid_predict(Sleipnir::CDat& Results, vector<SVMLabelPair*>& SVMLabels, float A, float B);
00380 
00381     // read in a SVM model file that's only has the w vector written out for linear kernel
00382     void ReadModelLinear(char* model_file) {
00383       FreeModel();
00384       structmodel = read_struct_model_w_linear(model_file, &struct_parm);
00385     }
00386     
00387     STRUCTMODEL read_struct_model_w_linear(char *file, STRUCT_LEARN_PARM *sparm);
00388 };
00389 }
00390 
#endif // SVMPERFI_H
#endif // NO_SVM_PERF