Sleipnir
src/bayesnetpnl.cpp
00001 /*****************************************************************************
00002 * This file is provided under the Creative Commons Attribution 3.0 license.
00003 *
00004 * You are free to share, copy, distribute, transmit, or adapt this work
00005 * PROVIDED THAT you attribute the work to the authors listed below.
00006 * For more information, please see the following web page:
00007 * http://creativecommons.org/licenses/by/3.0/
00008 *
00009 * This file is a component of the Sleipnir library for functional genomics,
00010 * authored by:
00011 * Curtis Huttenhower (chuttenh@princeton.edu)
00012 * Mark Schroeder
00013 * Maria D. Chikina
00014 * Olga G. Troyanskaya (ogt@princeton.edu, primary contact)
00015 *
00016 * If you use this library, the included executable tools, or any related
00017 * code in your work, please cite the following publication:
00018 * Curtis Huttenhower, Mark Schroeder, Maria D. Chikina, and
00019 * Olga G. Troyanskaya.
00020 * "The Sleipnir library for computational functional genomics"
00021 *****************************************************************************/
00022 #include "stdafx.h"
00023 #ifdef PNL_ENABLED
00024 #pragma warning (disable: 4244 4267)
00025 #include <pnl_dll.hpp>
00026 #pragma warning (default: 4244 4267)
00027 #include "bayesnet.h"
00028 #include "dat.h"
00029 #include "dataset.h"
00030 #include "meta.h"
00031 
00032 namespace Sleipnir {
00033 
00034 const char  CBayesNetPNLImpl::c_szBN[]  = "bn";
00035 
00043 CBayesNetPNL::CBayesNetPNL( bool fGroup ) : CBayesNetPNLImpl(fGroup) { }
00044 
00045 CBayesNetPNLImpl::CBayesNetPNLImpl( bool fGroup ) : CBayesNetImpl(fGroup),
00046     m_pPNLNet(NULL) { }
00047 
00048 CBayesNetPNLImpl::~CBayesNetPNLImpl( ) {
00049 
00050     if( m_pPNLNet )
00051         delete m_pPNLNet; }
00052 
00053 bool CBayesNetPNL::Open( const char* szFile ) {
00054     CContextPersistence ConPer;
00055 
00056     if( !ConPer.LoadXML( szFile ) )
00057         return false;
00058     if( m_pPNLNet )
00059         delete m_pPNLNet;
00060     return !!( m_pPNLNet = (CBNet*)ConPer.Get( c_szBN ) ); }
00061 
00062 bool CBayesNetPNL::Save( const char* szFile ) const {
00063     CContextPersistence ConPer;
00064 
00065     ConPer.Put( m_pPNLNet, c_szBN );
00066     return ConPer.SaveAsXML( szFile ); }
00067 
00068 bool CBayesNetPNL::Learn( const IDataset* pData, size_t iIterations, bool fZero, bool fELR ) {
00069     CEMLearningEngineDumb*  pLearner;
00070 
00071     if( !m_pPNLNet || fELR )
00072         return false;
00073 
00074     pLearner = CEMLearningEngineDumb::Create( m_pPNLNet );
00075     pLearner->SetMaxIterEM( (int)iIterations );
00076     pLearner->Learn( pData, fZero );
00077 
00078     delete pLearner;
00079     return true; }
00080 
00081 bool CBayesNetPNLImpl::IsContinuous( ) const {
00082 
00083     return ( m_pPNLNet ? !m_pPNLNet->GetNodeType( 0 )->IsDiscrete( ) : false ); }
00084 
00085 bool CBayesNetPNLImpl::Evaluate( const IDataset* pData, CDat* pDatOut,
00086     vector<vector<float> >* pvecvecdOut, bool fZero ) const {
00087     CInfEngine*                         pInferrer;
00088     size_t                              i, j, k, l, iVal;
00089     CEvidence*                          pEvidence;
00090     intVector                           veciObserved;
00091     valueVector                         vecValues;
00092     int                                 iNode;
00093     const CFactor*                      pFactor;
00094     const CMatrix<float>*               pMatrix;
00095     CMatrixIterator<float>*             pIter;
00096     float                               d;
00097     const float*                        pd;
00098     vector<float>*                      pvecdCur;
00099     map<string,float>                   mapData;
00100     map<string,float>::const_iterator   iterDatum;
00101     string                              strCur;
00102 
00103     if( !m_pPNLNet )
00104         return false;
00105 
00106     pvecdCur = NULL;
00107     pInferrer = CJtreeInfEngine::Create( m_pPNLNet );
00108     iNode = 0;
00109     for( i = 0; i < pData->GetGenes( ); ++i ) {
00110         if( !( i % 250 ) )
00111             g_CatSleipnir( ).notice( "CBayesNetPNL::Evaluate( %d ) %d/%d", fZero, i,
00112                 pData->GetGenes( ) );
00113         for( j = ( i + 1 ); j < pData->GetGenes( ); ++j ) {
00114             if( !pData->IsExample( i, j ) )
00115                 continue;
00116             if( m_fGroup ) {
00117                 strCur = EncodeDatum( pData, i, j );
00118                 if( ( iterDatum = mapData.find( strCur ) ) != mapData.end( ) ) {
00119                     if( pDatOut )
00120                         pDatOut->Set( i, j, iterDatum->second );
00121                     if( pvecvecdOut ) {
00122                         pvecvecdOut->resize( pvecvecdOut->size( ) + 1 );
00123                         (*pvecvecdOut)[ pvecvecdOut->size( ) - 1 ].push_back(
00124                             iterDatum->second ); }
00125                     continue; } }
00126 
00127             veciObserved.clear( );
00128             vecValues.clear( );
00129             for( k = 1; k < m_pPNLNet->GetNumberOfNodes( ); ++k ) {
00130                 if( pData->IsHidden( k ) )
00131                     continue;
00132                 if( IsContinuous( ) ) {
00133                     if( CMeta::IsNaN( d = pData->GetContinuous( i, j, k ) ) ) {
00134                         if( fZero )
00135                             d = 0;
00136                         else
00137                             continue; }
00138                     vecValues.resize( vecValues.size( ) + 1 );
00139                     vecValues[ vecValues.size( ) - 1 ].SetFlt( d ); }
00140                 else {
00141                     if( ( iVal = pData->GetDiscrete( i, j, k ) ) == -1 ) {
00142                         if( fZero )
00143                             iVal = 0;
00144                         else
00145                             continue; }
00146                     vecValues.resize( vecValues.size( ) + 1 );
00147                     vecValues[ vecValues.size( ) - 1 ].SetInt( (int)iVal ); }
00148                 veciObserved.push_back( (int)k ); }
00149 
00150             pEvidence = CEvidence::Create( m_pPNLNet, veciObserved, vecValues );
00151             pInferrer->EnterEvidence( pEvidence );
00152             pInferrer->MarginalNodes( &iNode, 1 );
00153             pFactor = pInferrer->GetQueryJPD( );
00154             delete pEvidence;
00155 
00156             if( pvecvecdOut ) {
00157                 pvecvecdOut->resize( pvecvecdOut->size( ) + 1 );
00158                 pvecdCur = &(*pvecvecdOut)[ pvecvecdOut->size( ) - 1 ]; }
00159             if( pFactor->GetDistributionType( ) == dtTabular ) {
00160                 pMatrix = pFactor->GetMatrix( matTable );
00161                 pIter = pMatrix->InitIterator( );
00162                 while( true ) {
00163                     pd = pMatrix->Value( pIter );
00164                     pMatrix->Next( pIter );
00165                     if( !pMatrix->IsValueHere( pIter ) )
00166                         break;
00167                     mapData[ strCur ] = *pd;
00168                     if( pvecdCur )
00169                         pvecdCur->push_back( *pd );
00170                     if( pDatOut ) {
00171                         pDatOut->Set( i, j, *pd );
00172                         break; } }
00173                 delete pIter; }
00174             else {
00175                 pMatrix = pFactor->GetMatrix( matMean );
00176                 for( pIter = pMatrix->InitIterator( ); pMatrix->IsValueHere( pIter );
00177                     pMatrix->Next( pIter ) ) {
00178                     mapData[ strCur ] = *pMatrix->Value( pIter );
00179                     if( pvecdCur )
00180                         pvecdCur->push_back( *pMatrix->Value( pIter ) );
00181                     if( pDatOut ) {
00182                         pDatOut->Set( i, j, *pMatrix->Value( pIter ) );
00183                         break; } }
00184                 delete pIter;
00185                 if( !pvecdCur )
00186                     break;
00187                 pMatrix = pFactor->GetMatrix( matCovariance );
00188                 for( pIter = pMatrix->InitIterator( ); pMatrix->IsValueHere( pIter );
00189                     pMatrix->Next( pIter ) )
00190                     pvecdCur->push_back( *pMatrix->Value( pIter ) );
00191                 delete pIter;
00192 
00193                 veciObserved.clear( );
00194                 pFactor->GetDomain( &veciObserved );
00195                 for( l = k = 0; k < veciObserved.size( ); ++k )
00196                     l += m_pPNLNet->GetGraph( )->GetNumberOfParents( veciObserved[ k ] );
00197                 for( k = 0; k < l; ++k ) {
00198                     pMatrix = pFactor->GetMatrix( matWeights, (int)k );
00199                     for( pIter = pMatrix->InitIterator( ); pMatrix->IsValueHere( pIter );
00200                         pMatrix->Next( pIter ) ) 
00201                         pvecdCur->push_back( *pMatrix->Value( pIter ) );
00202                     delete pIter; } } } }
00203 
00204     delete pInferrer;
00205     return true; }
00206 
00207 }
00208 
00209 #endif // PNL_ENABLED