Sleipnir
|
00001 /***************************************************************************** 00002 * This file is provided under the Creative Commons Attribution 3.0 license. 00003 * 00004 * You are free to share, copy, distribute, transmit, or adapt this work 00005 * PROVIDED THAT you attribute the work to the authors listed below. 00006 * For more information, please see the following web page: 00007 * http://creativecommons.org/licenses/by/3.0/ 00008 * 00009 * This file is a component of the Sleipnir library for functional genomics, 00010 * authored by: 00011 * Curtis Huttenhower (chuttenh@princeton.edu) 00012 * Mark Schroeder 00013 * Maria D. Chikina 00014 * Olga G. Troyanskaya (ogt@princeton.edu, primary contact) 00015 * 00016 * If you use this library, the included executable tools, or any related 00017 * code in your work, please cite the following publication: 00018 * Curtis Huttenhower, Mark Schroeder, Maria D. Chikina, and 00019 * Olga G. Troyanskaya. 00020 * "The Sleipnir library for computational functional genomics" 00021 *****************************************************************************/ 00022 00023 #ifndef NO_SVM_STRUCTTREE 00024 #ifndef SVMSTRUCTTREEI_H 00025 #define SVMSTRUCTTREEI_H 00026 #include "pclset.h" 00027 #include "meta.h" 00028 #include "dat.h" 00029 00030 #include <stdio.h> 00031 #include <map> 00032 00033 #ifndef NO_SVM_STRUCT 00034 #define SVMSTRUCT_H 00035 extern "C" { 00036 00037 #define class Class 00038 00039 #include <svm_hierarchy/svm_light/svm_common.h> 00040 #include <svm_hierarchy/svm_light/svm_learn.h> 00041 #include <svm_hierarchy/svm_struct_api_types.h> 00042 #include <svm_hierarchy/svm_struct/svm_struct_common.h> 00043 #include <svm_hierarchy/svm_struct_api.h> 00044 #include <svm_hierarchy/svm_struct/svm_struct_learn.h> 00045 #undef class 00046 00047 } 00048 #endif 00049 00050 #include "svmstruct.h" 00051 00052 /* removed to support cygwin */ 00053 //#include <execinfo.h> 00054 00055 namespace SVMArc { 00056 00057 00058 00059 00060 00061 00062 struct TEMPNODE{ // temporary intermediate data stucture 00063 set<ONTONODE*> children; 00064 }; 00065 00066 //class for SVMStruct 00067 class CSVMSTRUCTTREE : CSVMSTRUCTBASE { 00068 00069 public: 00070 LEARN_PARM learn_parm; 00071 KERNEL_PARM kernel_parm; 00072 STRUCT_LEARN_PARM struct_parm; 00073 STRUCTMODEL structmodel; 00074 map<string,int> onto_map; 00075 map<int, string> onto_map_rev; 00076 00077 00078 int Alg; 00079 CSVMSTRUCTTREE() { 00080 initialize(); 00081 //set_struct_verbosity(5); 00082 } 00083 00084 void SetLossFunction(size_t loss_f) { 00085 struct_parm.loss_function = loss_f; 00086 } 00087 00088 void SetTradeoff(double tradeoff) { 00089 struct_parm.C = tradeoff; 00090 } 00091 void SetLearningAlgorithm(int alg) { 00092 Alg = alg; 00093 } 00094 00095 void SetEpsilon(float eps) { 00096 struct_parm.epsilon = eps; 00097 } 00098 00099 void UseSlackRescaling() { 00100 struct_parm.loss_type = SLACK_RESCALING; 00101 } 00102 00103 void UseMarginRescaling() { 00104 struct_parm.loss_type = MARGIN_RESCALING; 00105 } 00106 00107 void SetNThreads(int n) { 00108 struct_parm.n_threads=n; 00109 } 00110 00111 void ReadModel(char* model_file) { 00112 00113 structmodel = read_struct_model(model_file, &struct_parm); 00114 if(structmodel.svm_model->kernel_parm.kernel_type == LINEAR) { /* linear kernel */ 00115 /* compute weight vector */ 00116 add_weight_vector_to_linear_model(structmodel.svm_model); 00117 structmodel.w=structmodel.svm_model->lin_weights; 00118 } 00119 } 00120 00121 void WriteModel(char* model_file) { 00122 //if (kernel_parm.kernel_type == LINEAR) { 00123 // ofstream ofsm; 00124 // ofsm.open(model_file); 00125 // for (size_t i = 0; i < structmodel.sizePsi; i++) { 00126 // ofsm << structmodel.w[i+1] << endl; 00127 // } 00128 //} else { 00129 write_struct_model(model_file, &structmodel, &struct_parm); 00130 /*}*/ 00131 } 00132 00133 void WriteWeights(ostream& osm) { 00134 osm << structmodel.w[0]; 00135 for (size_t i = 1; i < structmodel.sizePsi + 1; i++) 00136 osm << '\t' << structmodel.w[i]; 00137 osm << endl; 00138 } 00139 00140 static void FreePattern(pattern x) { 00141 free_pattern(x); 00142 } 00143 00144 static void FreeLabel(label y) { 00145 free_label(y); 00146 } 00147 00148 void FreeModel() { 00149 free_struct_model(structmodel); 00150 } 00151 00152 static void FreeSample(sample s) { 00153 free_struct_sample(s); 00154 } 00155 00156 static void FreeDoc(DOC* pDoc) { 00157 free_example(pDoc, true); 00158 } 00159 void SetVerbosity(size_t V); 00160 00161 00162 00163 void ReadOntology(const char* treefile); 00164 //creates a Doc for a given gene index in a single microarray 00165 static DOC* CreateDoc(Sleipnir::CPCL &PCL, size_t iGene, size_t iDoc); 00166 //static DOC* CreateDoc(Sleipnir::CDat& Dat, size_t iGene, size_t iDoc); 00167 //read labels 00168 vector<SVMLabel> ReadLabels(ifstream & ifsm); 00169 void vecsetZero (ONTONODE* node, vector<char>* ybar0,char zero); 00170 void preprocessLabel(vector<char>* multilabels); 00171 00172 void InitializeLikAfterReadLabels(); 00173 //Creates a sample using a single PCL and SVMlabels Looks up genes by name. 00174 SAMPLE* CreateSample(Sleipnir::CPCL &PCL, vector<SVMLabel> SVMLabels); 00175 //SAMPLE* CreateSample(Sleipnir::CDat& Dat, vector<SVMLabel> SVMLabels); 00176 00177 //Classify single genes 00178 vector<Result> Classify(Sleipnir::CPCL& PCL, vector<SVMLabel> SVMLabels); 00179 00180 //MEMBER functions wraps learning 00181 void Learn(SAMPLE &sample) { 00182 //cerr << "SLACK NORM =" << struct_parm.slack_norm << endl; 00183 cerr << "Algorithm " << Alg << " selected."<<endl; 00184 00185 if(Alg == 0) 00186 svm_learn_struct(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,NSLACK_ALG); 00187 else if(Alg == 1) 00188 svm_learn_struct(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,NSLACK_SHRINK_ALG); 00189 else if(Alg == 2) 00190 svm_learn_struct_joint(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,ONESLACK_PRIMAL_ALG); 00191 else if(Alg == 3) 00192 svm_learn_struct_joint(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,ONESLACK_DUAL_ALG); 00193 else if(Alg == 4) 00194 svm_learn_struct_joint(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,ONESLACK_DUAL_CACHE_ALG); 00195 else if(Alg == 9) 00196 svm_learn_struct_joint_custom(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel); 00197 else 00198 exit(1); 00199 // 00200 } 00201 00202 00203 00204 struct SortResults { 00205 bool operator()(const Result& rOne, const Result & rTwo) const { 00206 return (rOne.GeneName < rTwo.GeneName); //sort results by name 00207 } 00208 }; 00209 00210 size_t PrintResults(vector<Result> vecResults, ofstream & ofsm) { 00211 sort(vecResults.begin(), vecResults.end(), SortResults()); 00212 int LabelVal; 00213 for (size_t i = 0; i < vecResults.size(); i++) { 00214 ofsm << vecResults[i].toStringTREE(&onto_map_rev,0)<<endl; 00215 } 00216 }; 00217 00218 00219 bool parms_check(); 00220 bool initialize(); 00221 00222 00223 00224 // free the sample but don't free the Docs 00225 static void FreeSample_leave_Doc(SAMPLE s); 00226 00227 00228 00229 STRUCTMODEL read_struct_model_w_linear(char *file, STRUCT_LEARN_PARM *sparm); 00230 }; 00231 00232 00233 00234 00235 00236 }; 00237 00238 00239 #endif // NO_SVM_SVMSTRUCTTREE 00240 #endif // SVMSTRUCTTREE_H