Sleipnir
tools/BNCreator/BNCreator.cpp
00001 /*****************************************************************************
00002 * This file is provided under the Creative Commons Attribution 3.0 license.
00003 *
00004 * You are free to share, copy, distribute, transmit, or adapt this work
00005 * PROVIDED THAT you attribute the work to the authors listed below.
00006 * For more information, please see the following web page:
00007 * http://creativecommons.org/licenses/by/3.0/
00008 *
00009 * This file is a component of the Sleipnir library for functional genomics,
00010 * authored by:
00011 * Curtis Huttenhower (chuttenh@princeton.edu)
00012 * Mark Schroeder
00013 * Maria D. Chikina
00014 * Olga G. Troyanskaya (ogt@princeton.edu, primary contact)
00015 *
00016 * If you use this library, the included executable tools, or any related
00017 * code in your work, please cite the following publication:
00018 * Curtis Huttenhower, Mark Schroeder, Maria D. Chikina, and
00019 * Olga G. Troyanskaya.
00020 * "The Sleipnir library for computational functional genomics"
00021 *****************************************************************************/
00022 #include "stdafx.h"
00023 #include "cmdline.h"
00024 
// Filename extension appended to each network node name when building the
// path of the corresponding data file.
static const char c_acDab[] = ".dab";
00026 
00027 struct STerm {
00028     string                  m_strInput;
00029     string                  m_strOutput;
00030     CGenes*                 m_pGenes;
00031     CBayesNetSmile          m_BNRoot;
00032     vector<CBayesNetSmile*> m_vecpBNs;
00033 
00034     STerm( size_t iNodes, const string& strInput, const string& strOutput ) : m_strInput(strInput),
00035         m_strOutput(strOutput), m_pGenes(NULL) {
00036 
00037         m_vecpBNs.resize( iNodes ); }
00038 
00039     ~STerm( ) {
00040         size_t  i;
00041 
00042         if( m_pGenes )
00043             delete m_pGenes;
00044         for( i = 0; i < m_vecpBNs.size( ); ++i )
00045             delete m_vecpBNs[ i ]; }
00046 
00047     bool Open( CGenome& Genome ) {
00048         ifstream    ifsm;
00049 
00050         m_pGenes = new CGenes( Genome );
00051         ifsm.open( m_strInput.c_str( ) );
00052         return m_pGenes->Open( ifsm ); }
00053 
00054     bool LearnRoot( const CDataPair& Answers, const IDataset* pData,
00055         const CBayesNetSmile* pBNDefault ) {
00056         CDataFilter     Data;
00057         vector<string>  vecstrDummy;
00058         vector<size_t>  veciZeros;
00059 
00060         if( m_pGenes ) {
00061             Data.Attach( pData, *m_pGenes, CDat::EFilterTerm, &Answers );
00062             pData = &Data; }
00063         vecstrDummy.push_back( "FR" );
00064         if( !m_BNRoot.Open( pData, vecstrDummy, veciZeros ) ) {
00065             cerr << "Couldn't create base network (" << m_strInput << ')' << endl;
00066             return false; }
00067         if( pBNDefault )
00068             m_BNRoot.SetDefault( *pBNDefault );
00069         if( !m_BNRoot.Learn( pData, 1 ) ) {
00070             cerr << "Couldn't learn base network (" << m_strInput << ')' << endl;
00071             return false; }
00072 
00073         return true; }
00074 
00075     bool LearnNode( size_t iNode, const CDataPair& Answers, const IDataset* pData, const vector<string>& vecstrNames,
00076         const vector<size_t>& veciZeros, const CBayesNetSmile* pBNDefault, bool fZero ) {
00077         CDataFilter     Data;
00078         CBayesNetSmile* pBN;
00079 
00080         if( m_pGenes ) {
00081             Data.Attach( pData, *m_pGenes, CDat::EFilterTerm, &Answers );
00082             pData = &Data; }
00083         m_vecpBNs[iNode] = pBN = new CBayesNetSmile( );
00084         if( !pBN->Open( pData, vecstrNames, veciZeros ) ) {
00085             cerr << "Couldn't create network for (" << m_strInput << "): " << vecstrNames[ 1 ] << endl;
00086             return false; }
00087         if( pBNDefault )
00088             pBN->SetDefault( *pBNDefault );
00089         if( !pBN->Learn( pData, 1, fZero ) ) {
00090             cerr << "Couldn't learn network for (" << m_strInput << "): " << vecstrNames[ 1 ] << endl;
00091             return false; }
00092 
00093         return true; }
00094 
00095     bool Save( ) const {
00096         CBayesNetSmile  BNOut;
00097 
00098         if( !BNOut.Open( m_BNRoot, m_vecpBNs ) ) {
00099             cerr << "Couldn't merge networks (" << m_strInput << ')' << endl;
00100             return false; }
00101         BNOut.Save( m_strOutput.c_str( ) );
00102         return true; }
00103 };
00104 
00105 struct SLearn {
00106     size_t                      m_iNode;
00107     string                      m_strInput;
00108     const CDataPair*            m_pAnswers;
00109     const gengetopt_args_info*  m_psArgs;
00110     const map<string,size_t>*   m_pmapZeros;
00111     const CBayesNetSmile*       m_pBNDefault;
00112     vector<STerm*>*             m_pvecpsOutputs;
00113 };
00114 
// Thread entry point in the form pthread_create expects; pData is cast to an
// SLearn* by the implementation.
void* learn( void* pData );
00116 
00117 int main( int iArgs, char** aszArgs ) {
00118     gengetopt_args_info sArgs;
00119     size_t              i;
00120     map<string,size_t>  mapZeros;
00121     CBayesNetSmile      BNIn;
00122     vector<string>      vecstrNames;
00123 
00124 #ifdef WIN32
00125     pthread_win32_process_attach_np( );
00126 #endif // WIN32
00127     if( cmdline_parser( iArgs, aszArgs, &sArgs ) ) {
00128         cmdline_parser_print_help( );
00129         return 1; }
00130     CMeta Meta( sArgs.verbosity_arg );
00131 
00132     if( sArgs.zeros_arg ) {
00133         ifstream        ifsm;
00134         vector<string>  vecstrZeros;
00135         char            acLine[ 1024 ];
00136 
00137         ifsm.open( sArgs.zeros_arg );
00138         if( !ifsm.is_open( ) ) {
00139             cerr << "Couldn't open: " << sArgs.zeros_arg << endl;
00140             return 1; }
00141         while( !ifsm.eof( ) ) {
00142             ifsm.getline( acLine, ARRAYSIZE(acLine) - 1 );
00143             acLine[ ARRAYSIZE(acLine) - 1 ] = 0;
00144             vecstrZeros.clear( );
00145             CMeta::Tokenize( acLine, vecstrZeros );
00146             if( vecstrZeros.empty( ) )
00147                 continue;
00148             mapZeros[ vecstrZeros[ 0 ] ] = atoi( vecstrZeros[ 1 ].c_str( ) ); } }
00149 
00150     if( sArgs.input_arg ) {
00151         vector<string>  vecstrFiles, vecstrGenes;
00152         CDat            DatYes, DatNo;
00153         CDatasetCompact Data;
00154         vector<size_t>  veciGenes;
00155         size_t          j, k, iOne, iTwo, iBin, iZero, iIndex;
00156         double          d;
00157         CDataMatrix     MatCPT;
00158         ofstream        ofsm;
00159         char            acTemp[ L_tmpnam + 1 ];
00160         const char*     szTemp;
00161         float*          adYes;
00162         float*          adNo;
00163         CGenome         Genome;
00164         CGenes          GenesIn( Genome ), GenesEx( Genome );
00165 
00166         if( !BNIn.Open( sArgs.input_arg ) ) {
00167             cerr << "Couldn't open: " << sArgs.input_arg << endl;
00168             return 1; }
00169 
00170         BNIn.GetNodes( vecstrFiles );
00171         vecstrFiles.erase( vecstrFiles.begin( ) );
00172         for( i = 0; i < vecstrFiles.size( ); ++i )
00173             vecstrFiles[ i ] = (string)sArgs.directory_arg + '/' + vecstrFiles[ i ] + c_acDab;
00174         if( !Data.OpenGenes( vecstrFiles ) ) {
00175             cerr << "Couldn't open: " << sArgs.directory_arg << endl;
00176             return 1; }
00177         if( sArgs.genes_arg ) {
00178             ifstream    ifsm;
00179 
00180             ifsm.open( sArgs.genes_arg );
00181             if( !GenesIn.Open( ifsm ) ) {
00182                 cerr << "Couldn't open: " << sArgs.genes_arg << endl;
00183                 return 1; }
00184             ifsm.close( );
00185             for( i = 0; i < GenesIn.GetGenes( ); ++i )
00186                 vecstrGenes.push_back( GenesIn.GetGene( i ).GetName( ) ); }
00187 
00188 #pragma warning( disable : 4996 )
00189         if( !( szTemp = tmpnam( acTemp ) ) ) {
00190             cerr << "Couldn't create temp file: " << acTemp << endl;
00191             return 1; }
00192 #pragma warning( default : 4996 )
00193         DatYes.Open( vecstrGenes.size( ) ? vecstrGenes : Data.GetGeneNames( ), false, sArgs.output_arg );
00194         DatNo.Open( DatYes.GetGeneNames( ), false ); // , szTemp );
00195         adYes = new float[ DatYes.GetGenes( ) ];
00196         adNo = new float[ DatNo.GetGenes( ) ];
00197         BNIn.GetCPT( 0, MatCPT );
00198         d = log( MatCPT.Get( 0, 0 ) );
00199         for( i = 0; i < DatNo.GetGenes( ); ++i )
00200             adNo[ i ] = (float)d;
00201         for( i = 0; i < DatNo.GetGenes( ); ++i )
00202             DatNo.Set( i, adNo );
00203         d = log( MatCPT.Get( 1, 0 ) );
00204         for( i = 0; i < DatYes.GetGenes( ); ++i )
00205             adYes[ i ] = (float)d;
00206         for( i = 0; i < DatYes.GetGenes( ); ++i )
00207             DatYes.Set( i, adYes );
00208 
00209         BNIn.GetNodes( vecstrNames );
00210         veciGenes.resize( DatYes.GetGenes( ) );
00211         for( i = 0; i < vecstrFiles.size( ); ++i ) {
00212             vector<string>                      vecstrDatum;
00213             map<string,size_t>::const_iterator  iterZero;
00214 
00215             vecstrDatum.push_back( vecstrFiles[ i ] );
00216             if( !Data.Open( vecstrDatum, !!sArgs.memmap_flag ) ) {
00217                 cerr << "Couldn't open: " << vecstrFiles[ i ] << endl;
00218                 return 1; }
00219             iZero = ( ( iterZero = mapZeros.find( vecstrNames[ i + 1 ] ) ) == mapZeros.end( ) ) ? -1 :
00220                     iterZero->second;
00221             BNIn.GetCPT( i + 1, MatCPT );
00222             for( j = 0; j < veciGenes.size( ); ++j )
00223                 veciGenes[ j ] = Data.GetGene( DatYes.GetGene( j ) );
00224             for( j = 0; j < DatYes.GetGenes( ); ++j ) {
00225                 iBin = -1;
00226                 if( ( ( iOne = veciGenes[ j ] ) == -1 ) && ( iZero == -1 ) )
00227                     continue;
00228                 memcpy( adYes, DatYes.Get( j ), ( DatYes.GetGenes( ) - j - 1 ) * sizeof(*adYes) );
00229                 memcpy( adNo, DatNo.Get( j ), ( DatNo.GetGenes( ) - j - 1 ) * sizeof(*adNo) );
00230                 for( k = ( j + 1 ); k < DatYes.GetGenes( ); ++k ) {
00231                     if( ( iOne == -1 ) || ( ( iTwo = veciGenes[ k ] ) == -1 ) ||
00232                         !Data.IsExample( iOne, iTwo ) ||
00233                         ( ( iBin = Data.GetDiscrete( iOne, iTwo, 0 ) ) == -1 ) )
00234                         iBin = iZero;
00235                     if( iBin == -1 )
00236                         continue;
00237                     adNo[ iIndex = ( k - j - 1 ) ] += log( MatCPT.Get( iBin, 0 ) );
00238                     adYes[ iIndex ] += log( MatCPT.Get( iBin, 1 ) ); }
00239                 DatYes.Set( j, adYes );
00240                 DatNo.Set( j, adNo ); } }
00241         for( i = 0; i < DatYes.GetGenes( ); ++i ) {
00242             memcpy( adYes, DatYes.Get( i ), ( DatYes.GetGenes( ) - i - 1 ) * sizeof(*adYes) );
00243             memcpy( adNo, DatNo.Get( i ), ( DatNo.GetGenes( ) - i - 1 ) * sizeof(*adNo) );
00244             for( j = 0; j < ( DatYes.GetGenes( ) - i - 1 ); ++j )
00245                 adYes[ j ] = (float)( 1 / ( 1 + exp( (double)adNo[ j ] - (double)adYes[ j ] ) ) );
00246             DatYes.Set( i, adYes ); }
00247         _unlink( szTemp ); }
00248     else {
00249         size_t              iArg, iThread;
00250         CDataPair           Answers;
00251         CBayesNetSmile      BNDefault;
00252         CGenome             Genome;
00253         vector<STerm*>      vecpsOutputs;
00254         vector<string>      vecstrDummy;
00255         CDatasetCompact     Data;
00256         vector<pthread_t>   vecpthdThreads;
00257         vector<SLearn>      vecsData;
00258 
00259         if( sArgs.default_arg && !BNDefault.Open( sArgs.default_arg ) ) {
00260             cerr << "Couldn't open: " << sArgs.default_arg << endl;
00261             return 1; }
00262 
00263         if( !Answers.Open( sArgs.answers_arg, false ) ) {
00264             cerr << "Couldn't open: " << sArgs.answers_arg << endl;
00265             return 1; }
00266         if( sArgs.genes_arg && !Answers.FilterGenes( sArgs.genes_arg, CDat::EFilterInclude ) ) {
00267             cerr << "Couldn't open: " << sArgs.genes_arg << endl;
00268             return 1; }
00269         if( sArgs.genet_arg && !Answers.FilterGenes( sArgs.genet_arg, CDat::EFilterTerm ) ) {
00270             cerr << "Couldn't open: " << sArgs.genet_arg << endl;
00271             return 1; }
00272         if( sArgs.genee_arg && !Answers.FilterGenes( sArgs.genee_arg, CDat::EFilterEdge ) ) {
00273             cerr << "Couldn't open: " << sArgs.genee_arg << endl;
00274             return 1; }
00275         if( sArgs.genex_arg && !Answers.FilterGenes( sArgs.genex_arg, CDat::EFilterExclude ) ) {
00276             cerr << "Couldn't open: " << sArgs.genex_arg << endl;
00277             return 1; }
00278 
00279         if( sArgs.terms_arg ) {
00280             string          strFile;
00281 
00282             FOR_EACH_DIRECTORY_FILE((string)sArgs.terms_arg, strFile)
00283                 if( strFile[ 0 ] == '.' )
00284                     continue;
00285 
00286                 vecpsOutputs.push_back( new STerm( sArgs.inputs_num, (string)sArgs.terms_arg + '/' + strFile,
00287                     (string)sArgs.output_arg + '/' + strFile + ".xdsl" ) );
00288                 if( !vecpsOutputs[ vecpsOutputs.size( ) - 1 ]->Open( Genome ) ) {
00289                     cerr << "Could not open: " << strFile << endl;
00290                     return 1; } } }
00291         else
00292             vecpsOutputs.push_back( new STerm( sArgs.inputs_num, "", sArgs.output_arg ) );
00293 
00294         if( !Data.Open( Answers, vecstrDummy ) ) {
00295             cerr << "Couldn't open answer set" << endl;
00296             return 1; }
00297         for( i = 0; i < vecpsOutputs.size( ); ++i ) {
00298             if( !( i % 50 ) )
00299                 cerr << "Term " << i << '/' << vecpsOutputs.size( ) << endl;
00300             if( !vecpsOutputs[ i ]->LearnRoot( Answers, &Data, sArgs.default_arg ? &BNDefault : NULL ) )
00301                 return 1; }
00302 
00303         vecpthdThreads.resize( sArgs.inputs_num );
00304         vecsData.resize( vecpthdThreads.size( ) );
00305         for( iArg = 0; iArg < sArgs.inputs_num; iArg += iThread ) {
00306             for( iThread = 0; ( ( sArgs.threads_arg == -1 ) || ( iThread < (size_t)sArgs.threads_arg ) ) &&
00307                 ( ( iArg + iThread ) < sArgs.inputs_num ); ++iThread ) {
00308                 i = iArg + iThread;
00309                 vecsData[ i ].m_iNode = i;
00310                 vecsData[ i ].m_strInput = sArgs.inputs[ i ];
00311                 vecsData[ i ].m_pAnswers = &Answers;
00312                 vecsData[ i ].m_psArgs = &sArgs;
00313                 vecsData[ i ].m_pmapZeros = &mapZeros;
00314                 vecsData[ i ].m_pBNDefault = &BNDefault;
00315                 vecsData[ i ].m_pvecpsOutputs = &vecpsOutputs;
00316                 if( pthread_create( &vecpthdThreads[ i ], NULL, learn, &vecsData[ i ] ) ) {
00317                     cerr << "Couldn't create thread: " << sArgs.inputs[ i ] << endl;
00318                     return 1; } }
00319             for( i = 0; i < iThread; ++i )
00320                 pthread_join( vecpthdThreads[ iArg + i ], NULL ); }
00321 
00322         for( i = 0; i < vecpsOutputs.size( ); ++i ) {
00323             vecpsOutputs[ i ]->Save( );
00324             delete vecpsOutputs[ i ]; } }
00325 
00326 #ifdef WIN32
00327     pthread_win32_process_detach_np( );
00328 #endif // WIN32
00329     return 0; }
00330 
00331 void* learn( void* pData ) {
00332     CDatasetCompact Data;
00333     SLearn*         psData;
00334     vector<string>  vecstrNames;
00335     vector<size_t>  veciZeros;
00336     size_t          i;
00337 
00338     psData = (SLearn*)pData;
00339 
00340     vecstrNames.push_back( psData->m_strInput );
00341     if( !Data.Open( *psData->m_pAnswers, vecstrNames, psData->m_psArgs->zero_flag || psData->m_psArgs->zeros_arg,
00342         !!psData->m_psArgs->memmap_flag, psData->m_psArgs->skip_arg, !!psData->m_psArgs->zscore_flag ) ) {
00343         cerr << "Couldn't open: " << psData->m_strInput << endl;
00344         return NULL; }
00345     vecstrNames.insert( vecstrNames.begin( ), psData->m_psArgs->answers_arg );
00346     for( i = 0; i < vecstrNames.size( ); ++i )
00347         vecstrNames[ i ] = CMeta::Filename( CMeta::Deextension( CMeta::Basename(
00348             vecstrNames[ i ].c_str( ) ) ) );
00349     veciZeros.resize( vecstrNames.size( ) );
00350     for( i = 0; i < veciZeros.size( ); ++i ) {
00351         map<string,size_t>::const_iterator  iterZero;
00352 
00353         veciZeros[ i ] = ( ( iterZero = psData->m_pmapZeros->find( vecstrNames[ i ] ) ) ==
00354             psData->m_pmapZeros->end( ) ) ? -1 : iterZero->second; }
00355     for( i = 0; i < psData->m_pvecpsOutputs->size( ); ++i ) {
00356         if( !( i % 50 ) )
00357             cerr << "Term " << i << '/' << psData->m_pvecpsOutputs->size( ) << endl;
00358         if( !(*psData->m_pvecpsOutputs)[ i ]->LearnNode( psData->m_iNode, *psData->m_pAnswers, &Data, vecstrNames, veciZeros,
00359             psData->m_psArgs->default_arg ? psData->m_pBNDefault : NULL, !!psData->m_psArgs->zero_flag ) )
00360             return NULL; }
00361 
00362     return NULL; }