Sleipnir
src/libsvm.h
00001 /*****************************************************************************
00002  * This file is provided under the Creative Commons Attribution 3.0 license.
00003  *
00004  * You are free to share, copy, distribute, transmit, or adapt this work
00005  * PROVIDED THAT you attribute the work to the authors listed below.
00006  * For more information, please see the following web page:
00007  * http://creativecommons.org/licenses/by/3.0/
00008  *
00009  * This file is a component of the Sleipnir library for functional genomics,
00010  * authored by:
00011  * Curtis Huttenhower (chuttenh@princeton.edu)
00012  * Mark Schroeder
00013  * Maria D. Chikina
00014  * Olga G. Troyanskaya (ogt@princeton.edu, primary contact)
00015  *
00016  * If you use this library, the included executable tools, or any related
00017  * code in your work, please cite the following publication:
00018  * Curtis Huttenhower, Mark Schroeder, Maria D. Chikina, and
00019  * Olga G. Troyanskaya.
00020  * "The Sleipnir library for computational functional genomics"
00021  *****************************************************************************/
00022 
00023 #ifndef NO_LIBSVM
00024 #ifndef LIBSVMI_H
00025 #define LIBSVMI_H
00026 #include "pclset.h"
00027 #include "meta.h"
00028 #include "dat.h"
00029 
00030 #include <stdio.h>
00031 
00032 /* removed to support cygwin */
00033 //#include <execinfo.h>
00034 
00035 //#include <svm.h>
00036 
00037 namespace LIBSVM {
00038 
00039 
00040   extern "C" {
00041 #define class Class2
00042 #include <libsvm/svm.h>
00043 #undef class
00044 
00045   }
00046 
00047   typedef struct sample { /* a sample is a set of examples */
00048     size_t     n;            /* n is the total number of examples */
00049     size_t  numFeatures; 
00050     struct svm_problem *problems;
00051     sample() {
00052       n = 0;
00053       numFeatures = 0;
00054       problems = NULL;
00055     }
00056 
00057     ~sample(){
00058       //no destructor for problem struct
00059       free(problems->y);
00060       free(problems->x);
00061       problems = NULL;
00062     }
00063   } SAMPLE;
00064 
00065 
00066   class SVMLabel {
00067     public:
00068       string GeneName;
00069       double Target;
00070       size_t index;
00071       bool hasIndex;
00072 
00073       SVMLabel(std::string name, double target) {
00074         GeneName = name;
00075         Target = target;
00076         hasIndex = false;
00077         index = -1;
00078       }
00079 
00080       SVMLabel() {
00081         GeneName = "";
00082         Target = 0;
00083       }
00084       void SetIndex(size_t i) {
00085         index = i;
00086         hasIndex = true;
00087       }
00088   };
00089 
00090   class Result {
00091     public:
00092       std::string GeneName;
00093       double Target;
00094       double Value;
00095       int CVround;
00096       int Rank;
00097       Result() {
00098         GeneName = "";
00099         Target = 0;
00100         Value = Sleipnir::CMeta::GetNaN();
00101       }
00102 
00103       Result(std::string name, int cv = -1) {
00104         GeneName = name;
00105         Target = 0;
00106         Value = 0;
00107         CVround = cv;
00108         Rank = -1;
00109       }
00110       string toString() {
00111         stringstream ss;
00112         ss << GeneName << '\t' << Target << '\t' << Value << '\t' << "CV"
00113           << CVround;
00114         if (Rank != -1) {
00115           ss << '\t' << Rank;
00116         }
00117         return ss.str();
00118       }
00119 
00120   };
00121 
00122   enum EFilter {
00123     EFilterInclude = 0, EFilterExclude = EFilterInclude + 1,
00124   };
00125 
00126   //this class encapsulates the model and parameters and has no associated data
00127 
00128   class CLIBSVM {
00129     public:
00130       struct svm_model* model;
00131       struct svm_parameter parm;
00132       int balance;
00133 
00134       static struct svm_node *x_space;
00135 
00136       CLIBSVM() {
00137         initialize();
00138       }
00139 
00140       ~CLIBSVM() {
00141         svm_free_and_destroy_model( &model );
00142         model = NULL;
00143       }
00144 
00145       void SetBalance(int bal){
00146         balance = bal;
00147       }
00148 
00149       void SetSVMType(int type) {
00150         parm.svm_type = type;
00151       }
00152 
00153       void SetTradeoff(double tradeoff) {
00154         parm.C = tradeoff; //TODO: only applicable for vanilla svm
00155       }
00156 
00157       void SetKernel(int K) {
00158         parm.kernel_type = K;
00159       }
00160 
00161       void SetPolyD(int D) {
00162         parm.degree = D;
00163       }
00164 
00165       void SetRBFGamma(double g) {
00166         parm.gamma = g;
00167       }
00168 
00169       void SetNu(double nu) {
00170         parm.nu = nu;
00171       }
00172 
00173       void ReadModel(char* model_file) {
00174         FreeModel();
00175         model = svm_load_model(model_file); 
00176       }
00177 
00178       void FreeModel() {
00179         svm_free_and_destroy_model(&model);
00180       }
00181 
00182       void WriteModel(char* model_file) {
00183         svm_save_model(model_file, model);
00184       }
00185 
00186 
00187       //static members process data
00188       //
00189 
00190       static void SetXSpace(Sleipnir::CPCL& PCL);
00191 
00192       //single gene predictions
00193 
00194       //TODO: add functions to handle PCL files
00195 
00196       //Creates a sample of svm_problems using a single PCL and SVMlabels Looks up genes by name.
00197       static SAMPLE* CreateSample(Sleipnir::CPCL &PCL, vector<SVMLabel> SVMLabels);
00198 
00199       //TODO: Same as above except creates bootstrap samples and does not duplicate data
00200 
00201       //Creates a sample using a Dat and SVMlabels. Looks up genes by name
00202       static SAMPLE* CreateSample(Sleipnir::CDat& CDat,
00203           vector<SVMLabel> SMVLabels);
00204 
00205       //Classify single genes
00206       vector<Result> Classify(Sleipnir::CPCL& PCL, vector<SVMLabel> SVMLabels);
00207 
00208 
00209       //MEMBER functions wraps learning
00210       void Learn(SAMPLE &sample) {
00211         //only L2 for LibSVM
00212         //cerr << "SLACK NORM =" << struct_parm.slack_norm << endl;
00213         //slack_norm = type of regularization
00214 
00215         //Take care of the labels here
00216         size_t i;
00217         size_t numn, nump;
00218 
00219         struct svm_problem* prob = sample.problems;
00220 
00221         numn = nump = 0;
00222 
00223         for(i = 0; i < sample.n; i++){
00224           if (((*prob).y)[i] > 0){
00225             nump ++;
00226           }else{
00227             numn ++;
00228           }
00229         }
00230 
00231         if (balance) {
00232           cerr << "balancing the weights between postivies and negatives. " << endl;
00233           parm.nr_weight = 2;
00234           parm.weight_label = (int *) realloc(parm.weight_label, sizeof(int)*parm.nr_weight);
00235           parm.weight = (double *) realloc(parm.weight, sizeof(double)*parm.nr_weight);
00236           parm.weight_label[0] = 1;
00237           parm.weight[0] = numn;
00238           parm.weight_label[1] = -1;
00239           parm.weight[1] = nump;
00240         }
00241 
00242         if(parms_check()){
00243           model = svm_train(prob,&parm);
00244         }else{
00245         }
00246         prob = NULL;
00247 
00248       }
00249 
00250       static void PrintSample(SAMPLE s){
00251         PrintProblem(s.problems);
00252       }
00253 
00254       static void PrintProblem(svm_problem *prob){
00255         size_t i, j ;
00256         i = j = 0;
00257 
00258         for(i = 0 ; i < 3 ; i++){
00259           for(j = 0 ; j < 2 ; j ++){
00260             PrintNode((prob->x)[i][j]);
00261           }
00262         }
00263 
00264         return;
00265       }
00266 
00267       static void PrintNode(svm_node node){
00268         cerr << "index: " << node.index << endl;
00269         cerr << "value: " << node.value << endl;
00270       }
00271 
00272 
00273       //no pairwise learning for libSVM wrapper
00274 
00275       bool parms_check();
00276       bool initialize();
00277 
00278       //TODO: functions to convert probablity
00279 
00280   };
00281 }
00282 
#endif // LIBSVMI_H
#endif // NO_LIBSVM