Sleipnir
src/bayesnetfni.h
00001 /*****************************************************************************
00002 * This file is provided under the Creative Commons Attribution 3.0 license.
00003 *
00004 * You are free to share, copy, distribute, transmit, or adapt this work
00005 * PROVIDED THAT you attribute the work to the authors listed below.
00006 * For more information, please see the following web page:
00007 * http://creativecommons.org/licenses/by/3.0/
00008 *
00009 * This file is a component of the Sleipnir library for functional genomics,
00010 * authored by:
00011 * Curtis Huttenhower (chuttenh@princeton.edu)
00012 * Mark Schroeder
00013 * Maria D. Chikina
00014 * Olga G. Troyanskaya (ogt@princeton.edu, primary contact)
00015 *
00016 * If you use this library, the included executable tools, or any related
00017 * code in your work, please cite the following publication:
00018 * Curtis Huttenhower, Mark Schroeder, Maria D. Chikina, and
00019 * Olga G. Troyanskaya.
00020 * "The Sleipnir library for computational functional genomics"
00021 *****************************************************************************/
00022 #ifndef BAYESNETFNI_H
00023 #define BAYESNETFNI_H
00024 
00025 namespace Sleipnir {
00026 
00027 class CBayesNetMinimal;
00028 
00029 #ifndef NO_SMILE
00030 
00031 class CBayesNetFNNode {
00032 protected:
00033     friend class CBayesNetFN;
00034     friend class CBayesNetFNImpl;
00035 
00036     static const char   c_szType[];
00037 
00038     static CBayesNetFNNode* Open( DSL_node* );
00039 
00040     const std::string& GetName( ) const;
00041     unsigned char GetParameters( ) const;
00042     void Reverse( );
00043     bool Save( DSL_node* ) const;
00044     bool Learn( const std::vector<size_t>& );
00045 
00046     virtual const char* GetType( ) const = 0;
00047     virtual void Randomize( ) = 0;
00048     virtual CBayesNetFNNode* New( DSL_node* ) const = 0;
00049     virtual bool Learn( const IDataset*, size_t, size_t ) = 0;
00050     virtual bool Evaluate( float, std::vector<float>& ) const = 0;
00051 
00052     virtual bool IsContinuous( ) const {
00053 
00054         return true; }
00055 
00056     std::string         m_strName;
00057     const char*         m_szType;
00058     CFullMatrix<float>  m_Params;
00059 };
00060 
00061 class CBayesNetFNNodeDiscrete : protected CBayesNetFNNode {
00062 protected:
00063     friend class CBayesNetFNNode;
00064 
00065     void Randomize( );
00066     bool Learn( const IDataset*, size_t, size_t );
00067     bool Evaluate( float, std::vector<float>& ) const;
00068 
00069     CBayesNetFNNode* New( DSL_node* pNode ) const {
00070 
00071         return new CBayesNetFNNodeDiscrete( ); }
00072 
00073     const char* GetType( ) const {
00074 
00075         return "discrete"; }
00076 
00077     bool IsContinuous( ) const {
00078 
00079         return false; }
00080 };
00081 
00082 class CBayesNetFNNodeGaussian : protected CBayesNetFNNode {
00083 protected:
00084     friend class CBayesNetFNNode;
00085 
00086     static const size_t c_iMu       = 0;
00087     static const size_t c_iSigma    = 1;
00088 
00089     void Randomize( );
00090     bool Learn( const IDataset*, size_t, size_t );
00091     bool Evaluate( float, std::vector<float>& ) const;
00092 
00093     CBayesNetFNNode* New( DSL_node* pNode ) const {
00094 
00095         return new CBayesNetFNNodeGaussian( ); }
00096 
00097     const char* GetType( ) const {
00098 
00099         return "gaussian"; }
00100 };
00101 
00102 class CBayesNetFNNodeBeta : protected CBayesNetFNNode {
00103 protected:
00104     friend class CBayesNetFNNode;
00105 
00106     static const size_t c_iMin      = 0;
00107     static const size_t c_iMax      = 1;
00108     static const size_t c_iAlpha    = 2;
00109     static const size_t c_iBeta     = 3;
00110 
00111     void Randomize( );
00112     bool Learn( const IDataset*, size_t, size_t );
00113     bool Evaluate( float, std::vector<float>& ) const;
00114 
00115     CBayesNetFNNode* New( DSL_node* pNode ) const {
00116 
00117         return new CBayesNetFNNodeBeta( ); }
00118 
00119     const char* GetType( ) const {
00120 
00121         return "beta"; }
00122 };
00123 
00124 class CBayesNetFNNodeExponential : protected CBayesNetFNNode {
00125 protected:
00126     friend class CBayesNetFNNode;
00127 
00128     static const size_t c_iMin  = 0;
00129     static const size_t c_iBeta = 1;
00130 
00131     void Randomize( );
00132     bool Learn( const IDataset*, size_t, size_t );
00133     bool Evaluate( float, std::vector<float>& ) const;
00134 
00135     CBayesNetFNNode* New( DSL_node* pNode ) const {
00136 
00137         return new CBayesNetFNNodeExponential( ); }
00138 
00139     const char* GetType( ) const {
00140 
00141         return "exponential"; }
00142 };
00143 
00144 class CBayesNetFNNodeMOG : protected CBayesNetFNNode {
00145 protected:
00146     friend class CBayesNetFNNode;
00147 
00148     static const size_t c_iMu       = 0;
00149     static const size_t c_iSigma    = 1;
00150 
00151     void Randomize( );
00152     bool Learn( const IDataset*, size_t, size_t );
00153     bool Evaluate( float, std::vector<float>& ) const;
00154 
00155     CBayesNetFNNode* New( DSL_node* pNode ) const {
00156 
00157         return new CBayesNetFNNodeMOG( ); }
00158 
00159     const char* GetType( ) const {
00160 
00161         return "mog"; }
00162 };
00163 
00164 class CBayesNetFNImpl : protected CBayesNetImpl {
00165 protected:
00166     CBayesNetFNImpl( );
00167     ~CBayesNetFNImpl( );
00168 
00169     void Reset( );
00170     bool Evaluate( const IDataset*, CDat*, std::vector<std::vector<float> >*, bool ) const;
00171     bool Evaluate( const IDataset*, size_t, size_t, bool, std::vector<float>& ) const;
00172 
00173     size_t              m_iNodes;
00174     CBayesNetFNNode**   m_apNodes;
00175     bool                m_fSmileNet;
00176     DSL_network         m_SmileNet;
00177 };
00178 
00179 #endif // NO_SMILE
00180 
00181 class CBayesNetMinimalNode {
00182 public:
00183     CBayesNetMinimalNode( ) : m_bDefault(0xFF) { }
00184 
00185     unsigned char   m_bDefault;
00186     CDataMatrix     m_MatCPT;
00187 };
00188 
00189 class CBayesNetMinimalImpl : protected CBayesNetImpl, protected CFile {
00190 protected:
00191     static bool Counts2Probs( const std::vector<std::string>&, std::vector<float>&, float dAlpha = 1,
00192         float = HUGE_VAL, const CBayesNetMinimal* = NULL, size_t = 0, size_t = 0 );
00193 
00194     CBayesNetMinimalImpl( ) : CBayesNetImpl( true ), m_adNY(NULL) { }
00195 
00196     std::string                         m_strID;
00197     long double*                        m_adNY;
00198     CDataMatrix                         m_MatRoot;
00199     std::vector<CBayesNetMinimalNode>   m_vecNodes;
00200 };
00201 
00202 }
00203 
00204 #endif // BAYESNETFNI_H