Sleipnir
|
00001 /***************************************************************************** 00002 * This file is provided under the Creative Commons Attribution 3.0 license. 00003 * 00004 * You are free to share, copy, distribute, transmit, or adapt this work 00005 * PROVIDED THAT you attribute the work to the authors listed below. 00006 * For more information, please see the following web page: 00007 * http://creativecommons.org/licenses/by/3.0/ 00008 * 00009 * This file is a component of the Sleipnir library for functional genomics, 00010 * authored by: 00011 * Curtis Huttenhower (chuttenh@princeton.edu) 00012 * Mark Schroeder 00013 * Maria D. Chikina 00014 * Olga G. Troyanskaya (ogt@princeton.edu, primary contact) 00015 * 00016 * If you use this library, the included executable tools, or any related 00017 * code in your work, please cite the following publication: 00018 * Curtis Huttenhower, Mark Schroeder, Maria D. Chikina, and 00019 * Olga G. Troyanskaya. 00020 * "The Sleipnir library for computational functional genomics" 00021 *****************************************************************************/ 00022 #include "stdafx.h" 00023 #ifdef PNL_ENABLED 00024 #pragma warning (disable: 4244 4267) 00025 #include <pnl_dll.hpp> 00026 #pragma warning (default: 4244 4267) 00027 #include "bayesnet.h" 00028 #include "dat.h" 00029 #include "dataset.h" 00030 #include "meta.h" 00031 00032 namespace Sleipnir { 00033 00034 const char CBayesNetPNLImpl::c_szBN[] = "bn"; 00035 00043 CBayesNetPNL::CBayesNetPNL( bool fGroup ) : CBayesNetPNLImpl(fGroup) { } 00044 00045 CBayesNetPNLImpl::CBayesNetPNLImpl( bool fGroup ) : CBayesNetImpl(fGroup), 00046 m_pPNLNet(NULL) { } 00047 00048 CBayesNetPNLImpl::~CBayesNetPNLImpl( ) { 00049 00050 if( m_pPNLNet ) 00051 delete m_pPNLNet; } 00052 00053 bool CBayesNetPNL::Open( const char* szFile ) { 00054 CContextPersistence ConPer; 00055 00056 if( !ConPer.LoadXML( szFile ) ) 00057 return false; 00058 if( m_pPNLNet ) 00059 delete m_pPNLNet; 00060 return !!( m_pPNLNet = (CBNet*)ConPer.Get( c_szBN ) ); } 00061 00062 bool CBayesNetPNL::Save( const char* szFile ) const { 00063 CContextPersistence ConPer; 00064 00065 ConPer.Put( m_pPNLNet, c_szBN ); 00066 return ConPer.SaveAsXML( szFile ); } 00067 00068 bool CBayesNetPNL::Learn( const IDataset* pData, size_t iIterations, bool fZero, bool fELR ) { 00069 CEMLearningEngineDumb* pLearner; 00070 00071 if( !m_pPNLNet || fELR ) 00072 return false; 00073 00074 pLearner = CEMLearningEngineDumb::Create( m_pPNLNet ); 00075 pLearner->SetMaxIterEM( (int)iIterations ); 00076 pLearner->Learn( pData, fZero ); 00077 00078 delete pLearner; 00079 return true; } 00080 00081 bool CBayesNetPNLImpl::IsContinuous( ) const { 00082 00083 return ( m_pPNLNet ? !m_pPNLNet->GetNodeType( 0 )->IsDiscrete( ) : false ); } 00084 00085 bool CBayesNetPNLImpl::Evaluate( const IDataset* pData, CDat* pDatOut, 00086 vector<vector<float> >* pvecvecdOut, bool fZero ) const { 00087 CInfEngine* pInferrer; 00088 size_t i, j, k, l, iVal; 00089 CEvidence* pEvidence; 00090 intVector veciObserved; 00091 valueVector vecValues; 00092 int iNode; 00093 const CFactor* pFactor; 00094 const CMatrix<float>* pMatrix; 00095 CMatrixIterator<float>* pIter; 00096 float d; 00097 const float* pd; 00098 vector<float>* pvecdCur; 00099 map<string,float> mapData; 00100 map<string,float>::const_iterator iterDatum; 00101 string strCur; 00102 00103 if( !m_pPNLNet ) 00104 return false; 00105 00106 pvecdCur = NULL; 00107 pInferrer = CJtreeInfEngine::Create( m_pPNLNet ); 00108 iNode = 0; 00109 for( i = 0; i < pData->GetGenes( ); ++i ) { 00110 if( !( i % 250 ) ) 00111 g_CatSleipnir( ).notice( "CBayesNetPNL::Evaluate( %d ) %d/%d", fZero, i, 00112 pData->GetGenes( ) ); 00113 for( j = ( i + 1 ); j < pData->GetGenes( ); ++j ) { 00114 if( !pData->IsExample( i, j ) ) 00115 continue; 00116 if( m_fGroup ) { 00117 strCur = EncodeDatum( pData, i, j ); 00118 if( ( iterDatum = mapData.find( strCur ) ) != mapData.end( ) ) { 00119 if( pDatOut ) 00120 pDatOut->Set( i, j, iterDatum->second ); 00121 if( pvecvecdOut ) { 00122 pvecvecdOut->resize( pvecvecdOut->size( ) + 1 ); 00123 (*pvecvecdOut)[ pvecvecdOut->size( ) - 1 ].push_back( 00124 iterDatum->second ); } 00125 continue; } } 00126 00127 veciObserved.clear( ); 00128 vecValues.clear( ); 00129 for( k = 1; k < m_pPNLNet->GetNumberOfNodes( ); ++k ) { 00130 if( pData->IsHidden( k ) ) 00131 continue; 00132 if( IsContinuous( ) ) { 00133 if( CMeta::IsNaN( d = pData->GetContinuous( i, j, k ) ) ) { 00134 if( fZero ) 00135 d = 0; 00136 else 00137 continue; } 00138 vecValues.resize( vecValues.size( ) + 1 ); 00139 vecValues[ vecValues.size( ) - 1 ].SetFlt( d ); } 00140 else { 00141 if( ( iVal = pData->GetDiscrete( i, j, k ) ) == -1 ) { 00142 if( fZero ) 00143 iVal = 0; 00144 else 00145 continue; } 00146 vecValues.resize( vecValues.size( ) + 1 ); 00147 vecValues[ vecValues.size( ) - 1 ].SetInt( (int)iVal ); } 00148 veciObserved.push_back( (int)k ); } 00149 00150 pEvidence = CEvidence::Create( m_pPNLNet, veciObserved, vecValues ); 00151 pInferrer->EnterEvidence( pEvidence ); 00152 pInferrer->MarginalNodes( &iNode, 1 ); 00153 pFactor = pInferrer->GetQueryJPD( ); 00154 delete pEvidence; 00155 00156 if( pvecvecdOut ) { 00157 pvecvecdOut->resize( pvecvecdOut->size( ) + 1 ); 00158 pvecdCur = &(*pvecvecdOut)[ pvecvecdOut->size( ) - 1 ]; } 00159 if( pFactor->GetDistributionType( ) == dtTabular ) { 00160 pMatrix = pFactor->GetMatrix( matTable ); 00161 pIter = pMatrix->InitIterator( ); 00162 while( true ) { 00163 pd = pMatrix->Value( pIter ); 00164 pMatrix->Next( pIter ); 00165 if( !pMatrix->IsValueHere( pIter ) ) 00166 break; 00167 mapData[ strCur ] = *pd; 00168 if( pvecdCur ) 00169 pvecdCur->push_back( *pd ); 00170 if( pDatOut ) { 00171 pDatOut->Set( i, j, *pd ); 00172 break; } } 00173 delete pIter; } 00174 else { 00175 pMatrix = pFactor->GetMatrix( matMean ); 00176 for( pIter = pMatrix->InitIterator( ); pMatrix->IsValueHere( pIter ); 00177 pMatrix->Next( pIter ) ) { 00178 mapData[ strCur ] = *pMatrix->Value( pIter ); 00179 if( pvecdCur ) 00180 pvecdCur->push_back( *pMatrix->Value( pIter ) ); 00181 if( pDatOut ) { 00182 pDatOut->Set( i, j, *pMatrix->Value( pIter ) ); 00183 break; } } 00184 delete pIter; 00185 if( !pvecdCur ) 00186 break; 00187 pMatrix = pFactor->GetMatrix( matCovariance ); 00188 for( pIter = pMatrix->InitIterator( ); pMatrix->IsValueHere( pIter ); 00189 pMatrix->Next( pIter ) ) 00190 pvecdCur->push_back( *pMatrix->Value( pIter ) ); 00191 delete pIter; 00192 00193 veciObserved.clear( ); 00194 pFactor->GetDomain( &veciObserved ); 00195 for( l = k = 0; k < veciObserved.size( ); ++k ) 00196 l += m_pPNLNet->GetGraph( )->GetNumberOfParents( veciObserved[ k ] ); 00197 for( k = 0; k < l; ++k ) { 00198 pMatrix = pFactor->GetMatrix( matWeights, (int)k ); 00199 for( pIter = pMatrix->InitIterator( ); pMatrix->IsValueHere( pIter ); 00200 pMatrix->Next( pIter ) ) 00201 pvecdCur->push_back( *pMatrix->Value( pIter ) ); 00202 delete pIter; } } } } 00203 00204 delete pInferrer; 00205 return true; } 00206 00207 } 00208 00209 #endif // PNL_ENABLED