Sleipnir
src/svmstruct.h
00001 /*****************************************************************************
00002 * This file is provided under the Creative Commons Attribution 3.0 license.
00003 *
00004 * You are free to share, copy, distribute, transmit, or adapt this work
00005 * PROVIDED THAT you attribute the work to the authors listed below.
00006 * For more information, please see the following web page:
00007 * http://creativecommons.org/licenses/by/3.0/
00008 *
00009 * This file is a component of the Sleipnir library for functional genomics,
00010 * authored by:
00011 * Curtis Huttenhower (chuttenh@princeton.edu)
00012 * Mark Schroeder
00013 * Maria D. Chikina
00014 * Olga G. Troyanskaya (ogt@princeton.edu, primary contact)
00015 *
00016 * If you use this library, the included executable tools, or any related
00017 * code in your work, please cite the following publication:
00018 * Curtis Huttenhower, Mark Schroeder, Maria D. Chikina, and
00019 * Olga G. Troyanskaya.
00020 * "The Sleipnir library for computational functional genomics"
00021 *****************************************************************************/
00022 
00023 #ifndef NO_SVM_STRUCT
00024 #ifndef SVMSTRUCTI_H
00025 #define SVMSTRUCTI_H
00026 #include "pclset.h"
00027 #include "meta.h"
00028 #include "dat.h"
00029 #ifndef NO_SVM_STRUCT
00030 #define SVMSTRUCT_H
00031 extern "C" {
00032 
00033 #define class Class
00034 
00035 #include <svm_multiclass/svm_light/svm_common.h>
00036 #include <svm_multiclass/svm_light/svm_learn.h>
00037 #include <svm_multiclass/svm_struct_api_types.h>
00038 #include <svm_multiclass/svm_struct/svm_struct_common.h>
00039 #include <svm_multiclass/svm_struct_api.h>
00040 #include <svm_multiclass/svm_struct/svm_struct_learn.h>
00041 #undef class
00042     //#include "svm_struct_api.h"
00043 
00044 }
00045 #endif
00046 
00047 #include <stdio.h>
00048 using namespace Sleipnir;
00049 using namespace std;
00050 
00051 /* removed to support cygwin */
00052 //#include <execinfo.h>
00053 
00054 namespace SVMArc {
00055     class SVMLabel {
00056     public:
00057         string GeneName;
00058         size_t Target; //Save single integer label; used for single label classification (0-1, or multiclass)
00059         vector<char> TargetM; //Save multiple labels; used for hierarchical multi-label classification;
00060 
00061         size_t index;
00062         bool hasIndex;
00063         SVMLabel(std::string name, size_t target) {
00064             GeneName = name;
00065             Target = target;
00066             hasIndex = false;
00067             index = -1;
00068         }
00069 
00070         SVMLabel(std::string name, vector<char> cl) {
00071             GeneName = name;
00072             TargetM = cl;
00073             hasIndex = false;
00074             index = -1;
00075         }
00076         SVMLabel() {
00077             GeneName = "";
00078             Target = 0;
00079         }
00080         void SetIndex(size_t i) {
00081             index = i;
00082             hasIndex = true;
00083         }
00084     };
00085 
00086     class Result {
00087     public:
00088         std::string GeneName;
00089         int Target; //for single label prediction
00090         int Value; //for single label prediction
00091         vector<char> TargetM;//for multi label prediction
00092         vector<char> ValueM; //for multi label prediction
00093         vector<double> Scores;
00094         int num_class;
00095         int CVround;
00096         int Rank;
00097         Result() {
00098             GeneName = "";
00099             Target = 0;
00100             Value = -1;
00101         }
00102 
00103         Result(std::string name, int cv = -1) {
00104             GeneName = name;
00105             Target = 0;
00106             Value = 0;
00107             CVround = cv;
00108             Rank = -1;
00109             num_class = 0;
00110 
00111         }
00112         string toString() {
00113             stringstream ss;
00114             ss << GeneName << '\t' << Target << '\t' << Value << '\t' << "CV"
00115                 << CVround;
00116             if (Rank != -1) {
00117                 ss << '\t' << Rank;
00118             }
00119             return ss.str();
00120         }
00121         string toStringMC() {
00122             stringstream ss;
00123             ss << GeneName << '\t' << Target << '\t' << Value << '\t';
00124             for(size_t j=1;j<=num_class;j++)
00125                 ss << Scores[j]<<'\t';
00126             return ss.str();
00127         }
00128         string toStringTREE(map<int, string>* ponto_map_rev, int returnindex) {
00129             stringstream ss;
00130             int mark=1;
00131             ss << GeneName << '\t';
00132             for(size_t j=0;j<num_class;j++){
00133                 if(TargetM[j]==1)
00134                     if(mark){
00135                         if(returnindex)
00136                             ss<<j;
00137                         else
00138                             ss <<(*ponto_map_rev)[j];
00139                         mark = 0;
00140                     }
00141                     else
00142                         ss <<','<<(*ponto_map_rev)[j];
00143             }
00144             if(mark==1)
00145                 ss<<"??"<<'\t';
00146             else
00147                 ss<<'\t';
00148 
00149             mark=1;
00150             for(size_t j=0;j<num_class;j++){
00151                 if(ValueM[j]==1)
00152                     if(mark){
00153                         if(returnindex)
00154                             ss<<j;
00155                         else
00156                             ss <<(*ponto_map_rev)[j];
00157                         mark = 0;
00158                     }
00159                     else
00160                         ss <<','<<(*ponto_map_rev)[j];
00161             }
00162             if(mark)
00163                 ss<<"??";
00164             ss <<'\t';
00165             for(size_t j=0;j<num_class;j++)
00166                 ss << Scores[j]<<'\t';
00167             return ss.str();
00168         }
00169     };
00170 
00171     enum EFilter {
00172         EFilterInclude = 0, EFilterExclude = EFilterInclude + 1,
00173     };
00174 
00175     class CSVMSTRUCTBASE{
00176         /* This base class is solely intended to serve as a common template for different SVM Struct implementations
00177         A few required functions are not defined here because their parameter type or return type has to differ 
00178         among different implementations, but I listed them in comments. */
00179     public:
00180         virtual vector<Result> Classify(Sleipnir::CPCL& PCL, vector<SVMLabel> SVMLabels) = 0;
00181         virtual void SetTradeoff(double tradeoff)=0;
00182         virtual void SetLossFunction(size_t loss_f)=0;
00183         virtual void SetLearningAlgorithm(int alg)=0;
00184         virtual void UseSlackRescaling()=0;
00185         virtual void UseMarginRescaling()=0;
00186         virtual void ReadModel(char* model_file)=0;
00187         virtual void WriteModel(char* model_file)=0;
00188         virtual vector<SVMLabel> ReadLabels(ifstream & ifsm)=0;
00189         virtual void SetVerbosity(size_t V)=0;
00190         virtual bool parms_check() = 0;
00191         virtual bool initialize() = 0;
00192 
00193         /*The following functions should also be implemented
00194         SAMPLE* CreateSample(Sleipnir::CPCL &PCL, vector<SVMLabel> SVMLabels);
00195         static void FreeSample(sample s)
00196         void Learn(SAMPLE &sample)
00197         */
00198     };
00199 
00200 
00201 
00202 
00203     //this class encapsulates the model and parameters and has no associated data
00204 
00205 
00206     //class for SVMStruct
00207     class CSVMSTRUCTMC : CSVMSTRUCTBASE{
00208 
00209     public:
00210         LEARN_PARM learn_parm;
00211         KERNEL_PARM kernel_parm;
00212         STRUCT_LEARN_PARM struct_parm;
00213         STRUCTMODEL structmodel;
00214         int Alg;
00215         CSVMSTRUCTMC() {
00216             initialize();
00217             //set_struct_verbosity(5);
00218         }
00219 
00220         void SetLossFunction(size_t loss_f) {
00221             struct_parm.loss_function = loss_f;
00222         }
00223 
00224         void SetTradeoff(double tradeoff) {
00225             struct_parm.C = tradeoff;
00226         }
00227         void SetLearningAlgorithm(int alg) {
00228             Alg = alg;
00229         }
00230         void SetKernel(int K) {
00231             kernel_parm.kernel_type = K;
00232         }
00233         void SetPolyD(int D) {
00234             kernel_parm.poly_degree = D;
00235         }
00236 
00237         //void UseCPSP() {
00238         //  Alg = 9;
00239         //  struct_parm.preimage_method = 2;
00240         //  struct_parm.sparse_kernel_size = 500;
00241         //  struct_parm.bias = 0;
00242         //}
00243 
00244         //void SetRBFGamma(double g) {
00245         //  kernel_parm.rbf_gamma = g;
00246         //  UseCPSP();
00247         //}
00248 
00249         void UseSlackRescaling() {
00250             struct_parm.loss_type = SLACK_RESCALING;
00251         }
00252 
00253         void UseMarginRescaling() {
00254             struct_parm.loss_type = MARGIN_RESCALING;
00255         }
00256 
00257 
00258 
00259         void ReadModel(char* model_file) {
00260             structmodel = read_struct_model(model_file, &struct_parm);
00261             if(structmodel.svm_model->kernel_parm.kernel_type == LINEAR) { /* linear kernel */
00262                 /* compute weight vector */
00263                 add_weight_vector_to_linear_model(structmodel.svm_model);
00264                 structmodel.w=structmodel.svm_model->lin_weights;
00265             }
00266         }
00267 
00268         void WriteModel(char* model_file) {
00269             if (kernel_parm.kernel_type == LINEAR) {
00270                 ofstream ofsm;
00271                 ofsm.open(model_file);
00272                 for (size_t i = 0; i < structmodel.sizePsi; i++) {
00273                     ofsm << structmodel.w[i+1] << endl;
00274                 }
00275             } else {
00276                 write_struct_model(model_file, &structmodel, &struct_parm);
00277             }
00278         }
00279 
00280         void WriteWeights(ostream& osm) {
00281             osm << structmodel.w[0];
00282             for (size_t i = 1; i < structmodel.sizePsi + 1; i++)
00283                 osm << '\t' << structmodel.w[i];
00284             osm << endl;
00285         }
00286 
00287         static void FreePattern(pattern x) {
00288             free_pattern(x);
00289         }
00290 
00291         static void FreeLabel(label y) {
00292             free_label(y);
00293         }
00294 
00295         void FreeModel() {
00296             free_struct_model(structmodel);
00297         }
00298 
00299         static void FreeSample(sample s) {
00300             free_struct_sample(s);
00301         }
00302 
00303         static void FreeDoc(DOC* pDoc) {
00304             free_example(pDoc, true);
00305         }
00306         void SetVerbosity(size_t V);
00307 
00308         //static members process data
00309         //single gene predictions
00310 
00311 
00312         //creates a Doc for a given gene index in a single microarray
00313         static DOC* CreateDoc(Sleipnir::CPCL &PCL, size_t iGene, size_t iDoc);
00314 
00315 
00316         //read in labels
00317         vector<SVMLabel> ReadLabels(ifstream & ifsm);
00318 
00319         //Creates a sample using a single PCL and SVMlabels Looks up genes by name.
00320         SAMPLE
00321             * CreateSample(Sleipnir::CPCL &PCL, vector<SVMLabel> SVMLabels);
00322 
00323         //Classify single genes
00324         vector<Result> Classify(Sleipnir::CPCL& PCL, vector<SVMLabel> SVMLabels);
00325 
00326         //MEMBER functions wraps learning
00327         void Learn(SAMPLE &sample) {
00328             cerr << "SLACK NORM =" << struct_parm.slack_norm << endl;
00329             /*  if (kernel_parm.kernel_type==CUSTOM)
00330             svm_learn_struct_joint_custom(sample, &struct_parm, &learn_parm, &kernel_parm, &structmodel);
00331             else*/
00332 
00333 
00334             cerr << "ALG=" << Alg << endl;
00335 
00336             if(Alg == 0)
00337                 svm_learn_struct(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,NSLACK_ALG);
00338             else if(Alg == 1)
00339                 svm_learn_struct(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,NSLACK_SHRINK_ALG);
00340             else if(Alg == 2)
00341                 svm_learn_struct_joint(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,ONESLACK_PRIMAL_ALG);
00342             else if(Alg == 3)
00343                 svm_learn_struct_joint(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,ONESLACK_DUAL_ALG);
00344             else if(Alg == 4)
00345                 svm_learn_struct_joint(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,ONESLACK_DUAL_CACHE_ALG);
00346             else if(Alg == 9)
00347                 svm_learn_struct_joint_custom(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel);
00348             else
00349                 exit(1);
00350             //
00351         }
00352 
00353         struct SortResults {
00354 
00355             bool operator()(const Result& rOne, const Result & rTwo) const {
00356                 return (rOne.Value < rTwo.Value);
00357             }
00358         };
00359 
00360         size_t PrintResults(vector<Result> vecResults, ofstream & ofsm) {
00361             sort(vecResults.begin(), vecResults.end(), SortResults());
00362             int LabelVal;
00363             for (size_t i = 0; i < vecResults.size(); i++) {
00364                 ofsm << vecResults[i].GeneName << '\t' << vecResults[i].Target << '\t'
00365                     << vecResults[i].Value<<'\t';
00366                 for(size_t j=1;j<=vecResults[i].num_class;j++)
00367                     ofsm << vecResults[i].Scores[j]<<'\t';
00368                 ofsm<< endl;
00369 
00370             }
00371         };
00372 
00373         bool parms_check();
00374         bool initialize();
00375 
00376 
00377 
00378         // free the sample but don't free the Docs
00379         static void FreeSample_leave_Doc(SAMPLE s);
00380 
00381 
00382 
00383         STRUCTMODEL read_struct_model_w_linear(char *file, STRUCT_LEARN_PARM *sparm);
00384     };
00385 
00386 
00387 };
00388 
00389 
#endif // SVMSTRUCTI_H
#endif // NO_SVM_STRUCT