Sleipnir
|
00001 /***************************************************************************** 00002 * This file is provided under the Creative Commons Attribution 3.0 license. 00003 * 00004 * You are free to share, copy, distribute, transmit, or adapt this work 00005 * PROVIDED THAT you attribute the work to the authors listed below. 00006 * For more information, please see the following web page: 00007 * http://creativecommons.org/licenses/by/3.0/ 00008 * 00009 * This file is a component of the Sleipnir library for functional genomics, 00010 * authored by: 00011 * Curtis Huttenhower (chuttenh@princeton.edu) 00012 * Mark Schroeder 00013 * Maria D. Chikina 00014 * Olga G. Troyanskaya (ogt@princeton.edu, primary contact) 00015 * 00016 * If you use this library, the included executable tools, or any related 00017 * code in your work, please cite the following publication: 00018 * Curtis Huttenhower, Mark Schroeder, Maria D. Chikina, and 00019 * Olga G. Troyanskaya. 00020 * "The Sleipnir library for computational functional genomics" 00021 *****************************************************************************/ 00022 #ifndef BAYESNETFNI_H 00023 #define BAYESNETFNI_H 00024 00025 namespace Sleipnir { 00026 00027 class CBayesNetMinimal; 00028 00029 #ifndef NO_SMILE 00030 00031 class CBayesNetFNNode { 00032 protected: 00033 friend class CBayesNetFN; 00034 friend class CBayesNetFNImpl; 00035 00036 static const char c_szType[]; 00037 00038 static CBayesNetFNNode* Open( DSL_node* ); 00039 00040 const std::string& GetName( ) const; 00041 unsigned char GetParameters( ) const; 00042 void Reverse( ); 00043 bool Save( DSL_node* ) const; 00044 bool Learn( const std::vector<size_t>& ); 00045 00046 virtual const char* GetType( ) const = 0; 00047 virtual void Randomize( ) = 0; 00048 virtual CBayesNetFNNode* New( DSL_node* ) const = 0; 00049 virtual bool Learn( const IDataset*, size_t, size_t ) = 0; 00050 virtual bool Evaluate( float, std::vector<float>& ) const = 0; 00051 00052 virtual bool IsContinuous( ) const { 00053 00054 return true; } 00055 00056 std::string m_strName; 00057 const char* m_szType; 00058 CFullMatrix<float> m_Params; 00059 }; 00060 00061 class CBayesNetFNNodeDiscrete : protected CBayesNetFNNode { 00062 protected: 00063 friend class CBayesNetFNNode; 00064 00065 void Randomize( ); 00066 bool Learn( const IDataset*, size_t, size_t ); 00067 bool Evaluate( float, std::vector<float>& ) const; 00068 00069 CBayesNetFNNode* New( DSL_node* pNode ) const { 00070 00071 return new CBayesNetFNNodeDiscrete( ); } 00072 00073 const char* GetType( ) const { 00074 00075 return "discrete"; } 00076 00077 bool IsContinuous( ) const { 00078 00079 return false; } 00080 }; 00081 00082 class CBayesNetFNNodeGaussian : protected CBayesNetFNNode { 00083 protected: 00084 friend class CBayesNetFNNode; 00085 00086 static const size_t c_iMu = 0; 00087 static const size_t c_iSigma = 1; 00088 00089 void Randomize( ); 00090 bool Learn( const IDataset*, size_t, size_t ); 00091 bool Evaluate( float, std::vector<float>& ) const; 00092 00093 CBayesNetFNNode* New( DSL_node* pNode ) const { 00094 00095 return new CBayesNetFNNodeGaussian( ); } 00096 00097 const char* GetType( ) const { 00098 00099 return "gaussian"; } 00100 }; 00101 00102 class CBayesNetFNNodeBeta : protected CBayesNetFNNode { 00103 protected: 00104 friend class CBayesNetFNNode; 00105 00106 static const size_t c_iMin = 0; 00107 static const size_t c_iMax = 1; 00108 static const size_t c_iAlpha = 2; 00109 static const size_t c_iBeta = 3; 00110 00111 void Randomize( ); 00112 bool Learn( const IDataset*, size_t, size_t ); 00113 bool Evaluate( float, std::vector<float>& ) const; 00114 00115 CBayesNetFNNode* New( DSL_node* pNode ) const { 00116 00117 return new CBayesNetFNNodeBeta( ); } 00118 00119 const char* GetType( ) const { 00120 00121 return "beta"; } 00122 }; 00123 00124 class CBayesNetFNNodeExponential : protected CBayesNetFNNode { 00125 protected: 00126 friend class CBayesNetFNNode; 00127 00128 static const size_t c_iMin = 0; 00129 static const size_t c_iBeta = 1; 00130 00131 void Randomize( ); 00132 bool Learn( const IDataset*, size_t, size_t ); 00133 bool Evaluate( float, std::vector<float>& ) const; 00134 00135 CBayesNetFNNode* New( DSL_node* pNode ) const { 00136 00137 return new CBayesNetFNNodeExponential( ); } 00138 00139 const char* GetType( ) const { 00140 00141 return "exponential"; } 00142 }; 00143 00144 class CBayesNetFNNodeMOG : protected CBayesNetFNNode { 00145 protected: 00146 friend class CBayesNetFNNode; 00147 00148 static const size_t c_iMu = 0; 00149 static const size_t c_iSigma = 1; 00150 00151 void Randomize( ); 00152 bool Learn( const IDataset*, size_t, size_t ); 00153 bool Evaluate( float, std::vector<float>& ) const; 00154 00155 CBayesNetFNNode* New( DSL_node* pNode ) const { 00156 00157 return new CBayesNetFNNodeMOG( ); } 00158 00159 const char* GetType( ) const { 00160 00161 return "mog"; } 00162 }; 00163 00164 class CBayesNetFNImpl : protected CBayesNetImpl { 00165 protected: 00166 CBayesNetFNImpl( ); 00167 ~CBayesNetFNImpl( ); 00168 00169 void Reset( ); 00170 bool Evaluate( const IDataset*, CDat*, std::vector<std::vector<float> >*, bool ) const; 00171 bool Evaluate( const IDataset*, size_t, size_t, bool, std::vector<float>& ) const; 00172 00173 size_t m_iNodes; 00174 CBayesNetFNNode** m_apNodes; 00175 bool m_fSmileNet; 00176 DSL_network m_SmileNet; 00177 }; 00178 00179 #endif // NO_SMILE 00180 00181 class CBayesNetMinimalNode { 00182 public: 00183 CBayesNetMinimalNode( ) : m_bDefault(0xFF) { } 00184 00185 unsigned char m_bDefault; 00186 CDataMatrix m_MatCPT; 00187 }; 00188 00189 class CBayesNetMinimalImpl : protected CBayesNetImpl, protected CFile { 00190 protected: 00191 static bool Counts2Probs( const std::vector<std::string>&, std::vector<float>&, float dAlpha = 1, 00192 float = HUGE_VAL, const CBayesNetMinimal* = NULL, size_t = 0, size_t = 0 ); 00193 00194 CBayesNetMinimalImpl( ) : CBayesNetImpl( true ), m_adNY(NULL) { } 00195 00196 std::string m_strID; 00197 long double* m_adNY; 00198 CDataMatrix m_MatRoot; 00199 std::vector<CBayesNetMinimalNode> m_vecNodes; 00200 }; 00201 00202 } 00203 00204 #endif // BAYESNETFNI_H