Sleipnir
|
00001 /***************************************************************************** 00002 * This file is provided under the Creative Commons Attribution 3.0 license. 00003 * 00004 * You are free to share, copy, distribute, transmit, or adapt this work 00005 * PROVIDED THAT you attribute the work to the authors listed below. 00006 * For more information, please see the following web page: 00007 * http://creativecommons.org/licenses/by/3.0/ 00008 * 00009 * This file is a component of the Sleipnir library for functional genomics, 00010 * authored by: 00011 * Curtis Huttenhower (chuttenh@princeton.edu) 00012 * Mark Schroeder 00013 * Maria D. Chikina 00014 * Olga G. Troyanskaya (ogt@princeton.edu, primary contact) 00015 * 00016 * If you use this library, the included executable tools, or any related 00017 * code in your work, please cite the following publication: 00018 * Curtis Huttenhower, Mark Schroeder, Maria D. Chikina, and 00019 * Olga G. Troyanskaya. 00020 * "The Sleipnir library for computational functional genomics" 00021 *****************************************************************************/ 00022 00023 #ifndef NO_SVM_STRUCT 00024 #ifndef SVMSTRUCTI_H 00025 #define SVMSTRUCTI_H 00026 #include "pclset.h" 00027 #include "meta.h" 00028 #include "dat.h" 00029 #ifndef NO_SVM_STRUCT 00030 #define SVMSTRUCT_H 00031 extern "C" { 00032 00033 #define class Class 00034 00035 #include <svm_multiclass/svm_light/svm_common.h> 00036 #include <svm_multiclass/svm_light/svm_learn.h> 00037 #include <svm_multiclass/svm_struct_api_types.h> 00038 #include <svm_multiclass/svm_struct/svm_struct_common.h> 00039 #include <svm_multiclass/svm_struct_api.h> 00040 #include <svm_multiclass/svm_struct/svm_struct_learn.h> 00041 #undef class 00042 //#include "svm_struct_api.h" 00043 00044 } 00045 #endif 00046 00047 #include <stdio.h> 00048 using namespace Sleipnir; 00049 using namespace std; 00050 00051 /* removed to support cygwin */ 00052 //#include <execinfo.h> 00053 00054 namespace SVMArc { 00055 class SVMLabel { 00056 public: 00057 string GeneName; 00058 size_t Target; //Save single integer label; used for single label classification (0-1, or multiclass) 00059 vector<char> TargetM; //Save multiple labels; used for hierarchical multi-label classification; 00060 00061 size_t index; 00062 bool hasIndex; 00063 SVMLabel(std::string name, size_t target) { 00064 GeneName = name; 00065 Target = target; 00066 hasIndex = false; 00067 index = -1; 00068 } 00069 00070 SVMLabel(std::string name, vector<char> cl) { 00071 GeneName = name; 00072 TargetM = cl; 00073 hasIndex = false; 00074 index = -1; 00075 } 00076 SVMLabel() { 00077 GeneName = ""; 00078 Target = 0; 00079 } 00080 void SetIndex(size_t i) { 00081 index = i; 00082 hasIndex = true; 00083 } 00084 }; 00085 00086 class Result { 00087 public: 00088 std::string GeneName; 00089 int Target; //for single label prediction 00090 int Value; //for single label prediction 00091 vector<char> TargetM;//for multi label prediction 00092 vector<char> ValueM; //for multi label prediction 00093 vector<double> Scores; 00094 int num_class; 00095 int CVround; 00096 int Rank; 00097 Result() { 00098 GeneName = ""; 00099 Target = 0; 00100 Value = -1; 00101 } 00102 00103 Result(std::string name, int cv = -1) { 00104 GeneName = name; 00105 Target = 0; 00106 Value = 0; 00107 CVround = cv; 00108 Rank = -1; 00109 num_class = 0; 00110 00111 } 00112 string toString() { 00113 stringstream ss; 00114 ss << GeneName << '\t' << Target << '\t' << Value << '\t' << "CV" 00115 << CVround; 00116 if (Rank != -1) { 00117 ss << '\t' << Rank; 00118 } 00119 return ss.str(); 00120 } 00121 string toStringMC() { 00122 stringstream ss; 00123 ss << GeneName << '\t' << Target << '\t' << Value << '\t'; 00124 for(size_t j=1;j<=num_class;j++) 00125 ss << Scores[j]<<'\t'; 00126 return ss.str(); 00127 } 00128 string toStringTREE(map<int, string>* ponto_map_rev, int returnindex) { 00129 stringstream ss; 00130 int mark=1; 00131 ss << GeneName << '\t'; 00132 for(size_t j=0;j<num_class;j++){ 00133 if(TargetM[j]==1) 00134 if(mark){ 00135 if(returnindex) 00136 ss<<j; 00137 else 00138 ss <<(*ponto_map_rev)[j]; 00139 mark = 0; 00140 } 00141 else 00142 ss <<','<<(*ponto_map_rev)[j]; 00143 } 00144 if(mark==1) 00145 ss<<"??"<<'\t'; 00146 else 00147 ss<<'\t'; 00148 00149 mark=1; 00150 for(size_t j=0;j<num_class;j++){ 00151 if(ValueM[j]==1) 00152 if(mark){ 00153 if(returnindex) 00154 ss<<j; 00155 else 00156 ss <<(*ponto_map_rev)[j]; 00157 mark = 0; 00158 } 00159 else 00160 ss <<','<<(*ponto_map_rev)[j]; 00161 } 00162 if(mark) 00163 ss<<"??"; 00164 ss <<'\t'; 00165 for(size_t j=0;j<num_class;j++) 00166 ss << Scores[j]<<'\t'; 00167 return ss.str(); 00168 } 00169 }; 00170 00171 enum EFilter { 00172 EFilterInclude = 0, EFilterExclude = EFilterInclude + 1, 00173 }; 00174 00175 class CSVMSTRUCTBASE{ 00176 /* This base class is solely intended to serve as a common template for different SVM Struct implementations 00177 A few required functions are not defined here because their parameter type or return type has to differ 00178 among different implementations, but I listed them in comments. */ 00179 public: 00180 virtual vector<Result> Classify(Sleipnir::CPCL& PCL, vector<SVMLabel> SVMLabels) = 0; 00181 virtual void SetTradeoff(double tradeoff)=0; 00182 virtual void SetLossFunction(size_t loss_f)=0; 00183 virtual void SetLearningAlgorithm(int alg)=0; 00184 virtual void UseSlackRescaling()=0; 00185 virtual void UseMarginRescaling()=0; 00186 virtual void ReadModel(char* model_file)=0; 00187 virtual void WriteModel(char* model_file)=0; 00188 virtual vector<SVMLabel> ReadLabels(ifstream & ifsm)=0; 00189 virtual void SetVerbosity(size_t V)=0; 00190 virtual bool parms_check() = 0; 00191 virtual bool initialize() = 0; 00192 00193 /*The following functions should also be implemented 00194 SAMPLE* CreateSample(Sleipnir::CPCL &PCL, vector<SVMLabel> SVMLabels); 00195 static void FreeSample(sample s) 00196 void Learn(SAMPLE &sample) 00197 */ 00198 }; 00199 00200 00201 00202 00203 //this class encapsulates the model and parameters and has no associated data 00204 00205 00206 //class for SVMStruct 00207 class CSVMSTRUCTMC : CSVMSTRUCTBASE{ 00208 00209 public: 00210 LEARN_PARM learn_parm; 00211 KERNEL_PARM kernel_parm; 00212 STRUCT_LEARN_PARM struct_parm; 00213 STRUCTMODEL structmodel; 00214 int Alg; 00215 CSVMSTRUCTMC() { 00216 initialize(); 00217 //set_struct_verbosity(5); 00218 } 00219 00220 void SetLossFunction(size_t loss_f) { 00221 struct_parm.loss_function = loss_f; 00222 } 00223 00224 void SetTradeoff(double tradeoff) { 00225 struct_parm.C = tradeoff; 00226 } 00227 void SetLearningAlgorithm(int alg) { 00228 Alg = alg; 00229 } 00230 void SetKernel(int K) { 00231 kernel_parm.kernel_type = K; 00232 } 00233 void SetPolyD(int D) { 00234 kernel_parm.poly_degree = D; 00235 } 00236 00237 //void UseCPSP() { 00238 // Alg = 9; 00239 // struct_parm.preimage_method = 2; 00240 // struct_parm.sparse_kernel_size = 500; 00241 // struct_parm.bias = 0; 00242 //} 00243 00244 //void SetRBFGamma(double g) { 00245 // kernel_parm.rbf_gamma = g; 00246 // UseCPSP(); 00247 //} 00248 00249 void UseSlackRescaling() { 00250 struct_parm.loss_type = SLACK_RESCALING; 00251 } 00252 00253 void UseMarginRescaling() { 00254 struct_parm.loss_type = MARGIN_RESCALING; 00255 } 00256 00257 00258 00259 void ReadModel(char* model_file) { 00260 structmodel = read_struct_model(model_file, &struct_parm); 00261 if(structmodel.svm_model->kernel_parm.kernel_type == LINEAR) { /* linear kernel */ 00262 /* compute weight vector */ 00263 add_weight_vector_to_linear_model(structmodel.svm_model); 00264 structmodel.w=structmodel.svm_model->lin_weights; 00265 } 00266 } 00267 00268 void WriteModel(char* model_file) { 00269 if (kernel_parm.kernel_type == LINEAR) { 00270 ofstream ofsm; 00271 ofsm.open(model_file); 00272 for (size_t i = 0; i < structmodel.sizePsi; i++) { 00273 ofsm << structmodel.w[i+1] << endl; 00274 } 00275 } else { 00276 write_struct_model(model_file, &structmodel, &struct_parm); 00277 } 00278 } 00279 00280 void WriteWeights(ostream& osm) { 00281 osm << structmodel.w[0]; 00282 for (size_t i = 1; i < structmodel.sizePsi + 1; i++) 00283 osm << '\t' << structmodel.w[i]; 00284 osm << endl; 00285 } 00286 00287 static void FreePattern(pattern x) { 00288 free_pattern(x); 00289 } 00290 00291 static void FreeLabel(label y) { 00292 free_label(y); 00293 } 00294 00295 void FreeModel() { 00296 free_struct_model(structmodel); 00297 } 00298 00299 static void FreeSample(sample s) { 00300 free_struct_sample(s); 00301 } 00302 00303 static void FreeDoc(DOC* pDoc) { 00304 free_example(pDoc, true); 00305 } 00306 void SetVerbosity(size_t V); 00307 00308 //static members process data 00309 //single gene predictions 00310 00311 00312 //creates a Doc for a given gene index in a single microarray 00313 static DOC* CreateDoc(Sleipnir::CPCL &PCL, size_t iGene, size_t iDoc); 00314 00315 00316 //read in labels 00317 vector<SVMLabel> ReadLabels(ifstream & ifsm); 00318 00319 //Creates a sample using a single PCL and SVMlabels Looks up genes by name. 00320 SAMPLE 00321 * CreateSample(Sleipnir::CPCL &PCL, vector<SVMLabel> SVMLabels); 00322 00323 //Classify single genes 00324 vector<Result> Classify(Sleipnir::CPCL& PCL, vector<SVMLabel> SVMLabels); 00325 00326 //MEMBER functions wraps learning 00327 void Learn(SAMPLE &sample) { 00328 cerr << "SLACK NORM =" << struct_parm.slack_norm << endl; 00329 /* if (kernel_parm.kernel_type==CUSTOM) 00330 svm_learn_struct_joint_custom(sample, &struct_parm, &learn_parm, &kernel_parm, &structmodel); 00331 else*/ 00332 00333 00334 cerr << "ALG=" << Alg << endl; 00335 00336 if(Alg == 0) 00337 svm_learn_struct(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,NSLACK_ALG); 00338 else if(Alg == 1) 00339 svm_learn_struct(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,NSLACK_SHRINK_ALG); 00340 else if(Alg == 2) 00341 svm_learn_struct_joint(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,ONESLACK_PRIMAL_ALG); 00342 else if(Alg == 3) 00343 svm_learn_struct_joint(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,ONESLACK_DUAL_ALG); 00344 else if(Alg == 4) 00345 svm_learn_struct_joint(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,ONESLACK_DUAL_CACHE_ALG); 00346 else if(Alg == 9) 00347 svm_learn_struct_joint_custom(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel); 00348 else 00349 exit(1); 00350 // 00351 } 00352 00353 struct SortResults { 00354 00355 bool operator()(const Result& rOne, const Result & rTwo) const { 00356 return (rOne.Value < rTwo.Value); 00357 } 00358 }; 00359 00360 size_t PrintResults(vector<Result> vecResults, ofstream & ofsm) { 00361 sort(vecResults.begin(), vecResults.end(), SortResults()); 00362 int LabelVal; 00363 for (size_t i = 0; i < vecResults.size(); i++) { 00364 ofsm << vecResults[i].GeneName << '\t' << vecResults[i].Target << '\t' 00365 << vecResults[i].Value<<'\t'; 00366 for(size_t j=1;j<=vecResults[i].num_class;j++) 00367 ofsm << vecResults[i].Scores[j]<<'\t'; 00368 ofsm<< endl; 00369 00370 } 00371 }; 00372 00373 bool parms_check(); 00374 bool initialize(); 00375 00376 00377 00378 // free the sample but don't free the Docs 00379 static void FreeSample_leave_Doc(SAMPLE s); 00380 00381 00382 00383 STRUCTMODEL read_struct_model_w_linear(char *file, STRUCT_LEARN_PARM *sparm); 00384 }; 00385 00386 00387 }; 00388 00389 00390 #endif // NO_SVM_SVMSTRUCT 00391 #endif // SVMSTRUCT_H