
Neural Graphical Models

10/02/2022
by   Harsh Shrivastava, et al.

Graphs are ubiquitous and are often used to understand the dynamics of a system. Probabilistic Graphical Models comprising Bayesian and Markov networks, and Conditional Independence graphs, are some of the popular graph representation techniques. They can model relationships between features (nodes) together with the underlying distribution. Although these models can theoretically represent very complex dependency functions, in practice simplifying assumptions are often made due to the computational limitations associated with graph operations. This work introduces Neural Graphical Models (NGMs), which attempt to represent complex feature dependencies at reasonable computational cost. Specifically, given a graph of feature relationships and corresponding samples, we capture the dependency structure between the features along with their complex function representations by using neural networks as a multi-task learning framework. We provide efficient learning, inference and sampling algorithms for NGMs. Moreover, NGMs can fit generic graph structures, including directed, undirected and mixed-edge graphs, and support mixed input data types. We present empirical studies that show NGMs' capability to represent Gaussian graphical models, perform inference analysis on lung cancer data, and extract insights from real-world infant mortality data provided by the CDC.


1 Introduction

Graphical models are a powerful tool for analyzing data. They can represent the relationships between features and provide underlying distributions that model the functional dependencies between them. Probabilistic graphical models (PGMs) are quite popular and often used to describe various systems from different domains. Bayesian networks (directed acyclic graphs) and Markov networks (undirected graphs) are able to represent many complex systems due to their generic mathematical formulation (pearl88; koller2009probabilistic). These models rely on conditional independence assumptions to make representation of the domain and the probability distribution over it feasible.

Learning, inference and sampling are the operations that make such graphical models useful for domain exploration. Learning, in a broad sense, consists of fitting the distribution function parameters from data. Inference is the procedure of answering queries in the form of marginal distributions or conditional distributions given one or more observed variables. Sampling is the ability to draw samples from the underlying distribution defined by the graphical model. A common bottleneck of graphical model representations is the high computational complexity of one or more of these procedures. Devising approximate algorithms or analytically favorable underlying distributions has been a topic of interest to the research community for the past few decades.

In particular, various graphical models have placed restrictions on the set of distributions or types of variables in the domain. Some graphical models work with continuous variables only (or categorical variables only) or place restrictions on the graph structure (e.g., that continuous variables cannot be parents of categorical variables in a DAG). Other restrictions affect the set of probability distributions the models are capable of representing, e.g., to multivariate Gaussian.

Practically, for graphical models to be widely adoptable, the following properties are desired:


  • Facilitate rich representations of complex underlying distributions.

  • Support various relationship representations including directed, undirected, mixed-edge graphs.

  • Fast and efficient algorithms for learning, inference and sampling.

  • Direct access to the learned underlying distributions for analysis.

  • Handle different input data types such as categorical, continuous, images, text, and generic embedding representations.

In this work we propose Neural Graphical Models (NGMs), which satisfy the aforementioned desiderata in a computationally efficient way. NGMs accept a feature dependency structure that can be given by an expert or learned from data. The dependency structure may have the form of a graph with clearly defined semantics (e.g., a Bayesian network graph or a Markov network graph) or an adjacency matrix; the graph may be either directed or undirected. Based on this dependency structure, NGMs represent the probability function over the domain using a deep neural network. The parameterization of such a network can be learned from data efficiently, with a loss function that jointly optimizes adherence to the given dependency structure and fit to the data. Probability functions represented by NGMs are free of the common restrictions inherent in other PGMs. They also support efficient inference and sampling.

The rest of this paper is organized as follows: in Section 2 we briefly review work most closely related to ours, in Section 3 we introduce Neural Graphical Models including representation, learning, inference, sampling and handling of extended data types. We present experiments, both on synthetic and real-life data in Section 4 and Appendix B, discuss design considerations and limitations of our framework in Appendix A and close with conclusions and directions for future work in Section 5.

2 Related work

Probabilistic graphical models aim to learn the underlying joint distribution from which the input data is sampled. Often, inducing an independence graph structure between the features helps make learning of the distribution computationally feasible. In cases where this independence graph structure is provided by a domain expert, the problem of fitting PGMs reduces to learning distributions over this graph. Alternatively, many methods are traditionally used to jointly learn the structure as well as the parameters (heckerman1995learning; spirtes1995learning; koller2009probabilistic; scanagatta2019survey), and these have been widely used to analyse data in many domains (barton2012bayesian; bielza2014bayesian; borunda2016bayesian; shrivastava2019cooperative; shrivastava2020using).

A few researchers explored discriminative PGMs, learning not the joint probability distribution over a domain, but an approximation to the conditional distribution of a pre-selected subset of variables given the rest, typically in the context of undirected graphs. The best known are conditional random fields (CRFs) (Lafferty2001ConditionalRF). Discriminative models are more flexible in ignoring complex dependencies between most of the variables in the domain and focusing on their impact on a small subset. They often have faster and more accurate inference, albeit restricted to the pre-selected set of variables. Generative models have higher bias: they make more assumptions about the form of the distribution. This bias helps with regularization and avoiding overfitting, but generative models are typically poorer predictors than discriminative models. In this work, we attempt to combine the advantages of both approaches by creating a discriminative model capable of predicting the value of any variable in the domain.

Recently, many interesting deep learning based approaches for DAG recovery have been proposed (zheng2018dags; zheng2020learning; lachapelle2019gradient; yu2019dag). These works primarily focus on structure learning, but technically they are learning a probabilistic graphical model. They depend on existing algorithms developed for Bayesian networks for the inference and sampling tasks. A parallel line of work combining graphical models with deep learning consists of Bayesian deep learning approaches such as Variational AutoEncoders and Boltzmann Machines (wang2020survey). These deep learning models have significantly more parameters than traditional Bayesian networks. Thus, using these deep graphical models for downstream tasks is computationally expensive, which often impedes their adoption.

We would be remiss not to mention the technical similarities NGMs have with some recent research works. First, we found ‘Learning sparse nonparametric DAGs’ (zheng2020learning) to be the closest in terms of representation ability. In one of their versions, they model each independence structure with a different neural network (MLP). However, their choice of modeling the feature independence criterion differs from NGMs: they zero out the weights of the corresponding row in the first layer of the NN to induce independence between input and output features. This type of formulation prevents them from sharing the NNs across different factors. Second, we found a similar path-norm formulation, using the product of NN weights to express input-to-output connectivity, in lachapelle2019gradient. They use the path norm to parametrize the DAG constraint for continuous optimization, while shrivastava2020grnular; shrivastava2022grnular use it within an unrolled algorithm framework to learn sparse gene regulatory networks.

There are methods that model conditional independence graphs (friedman2008sparse; belilovsky2017learning; shrivastava2019glad; shrivastava2022uglad), a type of graphical model based on an underlying multivariate Gaussian distribution. Probabilistic Circuits (peharz2020einsum) and Conditional Random Fields or Markov Networks (sutton2012introduction) are some other popular formulations. These PGMs often make simplifying assumptions about the underlying distributions and have certain restrictions on the input data types they can handle. Real-world input data often consist of mixed datatypes (real, categorical, text, images etc.) and are challenging for existing graphical model formulations to handle.

Figure 1: Graphical view of NGMs: the input graph G (undirected) for the given input data X. Each feature is a function of its neighboring features. For a DAG, the functions between features are defined by the Markov blanket relationships. The adjacency matrix (right) represents the associated dependency structure S.
Figure 2: Neural view of NGMs: an NN as a multitask learning architecture capturing non-linear dependencies for the features of the undirected graph in Fig. 1. If there is a path from an input feature to an output feature, that indicates a dependency between them. The dependency matrix between the input and output of the neural network reduces to a simple matrix multiplication over the layer weights. Note that not all the zeroed-out weights of the MLP (black dashed lines) are shown, for the sake of clarity.

3 Neural Graphical Models

We propose a new probabilistic graphical model type, called Neural Graphical Models (NGMs) and describe the associated learning, inference and sampling algorithms. Our model accepts all input types and avoids placing any restrictions on the form of underlying distributions.

3.1 Problem setting

We are given input data X with multiple sample points, each sample consisting of the same set of features. An example of such data is gene expression data, where X is a matrix of microarray expression values (samples) over genes (features). Another example is a mix of continuous and categorical data describing a patient’s health in a medical domain. We are also provided a graph G, which can be directed, undirected or have mixed edge types, representing our belief about the feature dependency relationships (in a probabilistic sense). Such graphs are often provided by experts and encode inductive biases and domain knowledge about the underlying system. In cases where the graph is not provided, we make use of state-of-the-art algorithms to recover DAGs or CI graphs, as described in Sec. 2. The NGM input is the tuple (X, G).

3.2 Representation

Function proximal-init(X, S):
       Initialize the MLP with dimensions taken from S; fit the regression alone (no structure constraint) using the ‘adam’ optimizer for a small number of epochs
       return the initial weights
Function fit-NGM(X, S, initial weights):
       for each epoch do
             compute the loss of Eq. 2 (regression loss + λ · structure loss); backprop to update the MLP parameters; optionally update λ and detach it from the computational graph
       return the learned weights
Function NGM-learning(X, S):
       weights ← proximal-init(X, S); weights ← fit-NGM(X, S, weights); return weights
Algorithm 1 NGMs: Learning algorithm

Fig. 1 shows a sample recovered graph and how we view the value of each feature as a function of the values of its neighbors. In the case of directed graphs, each feature’s value is represented as a function of its Markov blanket in the graph. We use the graph G to understand the domain’s dependency structure, but ignore any potential parametrization associated with it.

We introduce a ‘neural’ view, which is another way of looking at G, represented in Fig. 2. These neural networks are multi-layer perceptrons (MLPs) with appropriate input and output dimensions that represent the graph connections in NGMs. Specifically, we view the neural networks as an ‘open box’ and focus on the paths from input to output; these paths represent functional dependencies. Consider a neural network with H layers and ReLU non-linearities. The dimensions of the weights and biases are chosen such that the number of neural network input and output units equals the number of features. The product of the (absolute) weight matrices of the neural network gives us the path dependencies: if the entry for an input-output pair is zero, then that output does not depend on that input. Increasing the number of layers and the hidden dimensions of the NNs provides richer dependence function complexity.
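As an illustration of this path-product view, here is a minimal sketch in PyTorch; the 2-layer MLP, its dimensions and the helper name are illustrative assumptions rather than the paper's exact implementation:

```python
import torch
import torch.nn as nn

# Hypothetical 2-layer MLP in the "neural view": D inputs -> H hidden -> D outputs.
D, H = 5, 10
mlp = nn.Sequential(nn.Linear(D, H), nn.ReLU(), nn.Linear(H, D))

def path_dependency_matrix(mlp: nn.Sequential) -> torch.Tensor:
    """Multiply the absolute weight matrices of consecutive linear layers.

    Entry (i, j) of the result is zero only if there is no path of nonzero
    weights from input feature i to output feature j.
    """
    S_nn = None
    for layer in mlp:
        if isinstance(layer, nn.Linear):
            W = layer.weight.abs()              # shape (out_features, in_features)
            S_nn = W if S_nn is None else W @ S_nn
    return S_nn.T                               # view as (input feature, output feature)

S_nn = path_dependency_matrix(mlp)
print(S_nn == 0)  # True entries would indicate independence of an output from an input
```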

Representing categorical variables. Assume that in the input X we have a column with several distinct categorical entries. One way to handle categorical input is to one-hot encode the column, ending up with one column per category. We replace the single categorical column with the corresponding one-hot representation in the original data. The MLP capturing the path dependencies S needs to be updated accordingly: whatever connections were previously attached to the categorical column should be maintained for all the one-hot columns as well. Thus, we connect all the one-hot columns so that they reproduce the path connections of the original categorical column.
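A small sketch of this structure update with NumPy, assuming S is stored as a binary adjacency (dependency) matrix; the function name and interface are illustrative:

```python
import numpy as np

def expand_structure_for_onehot(S: np.ndarray, col: int, n_categories: int) -> np.ndarray:
    """Replace row/column `col` of the binary dependency matrix S with
    n_categories copies, so that every one-hot column inherits the
    connections of the original categorical column."""
    rows = np.repeat(S[[col], :], n_categories, axis=0)          # duplicate the row
    S_rows = np.vstack([S[:col, :], rows, S[col + 1:, :]])
    cols = np.repeat(S_rows[:, [col]], n_categories, axis=1)     # duplicate the column
    return np.hstack([S_rows[:, :col], cols, S_rows[:, col + 1:]])
```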

3.3 Learning

Using the rich and compact functional representation of the ‘neural’ view, the learning task is to fit the neural networks to the desired dependency structure S while also fitting the regression to the input data X. Given the input data X, we want to learn the functions described by the NGM ‘graphical view’, Fig. 1. These can be obtained by solving the multiple regression problems shown in the neural view, Fig. 2. We achieve this by treating the neural view as a multi-task learning framework. The goal is to find the set of parameters that minimizes the loss, expressed as the distance between the input X and the network output, while maintaining the dependency structure provided in the input graph G. Writing $f_{\mathcal{W}}$ for the neural-view MLP with layer weights $W_l$, we can define the regression operation as follows:

\min_{\mathcal{W}} \sum_{k} \big\lVert X^{k} - f_{\mathcal{W}}(X^{k}) \big\rVert_2^2 \quad \text{s.t.} \quad \Big( \prod_{l} |W_l| \Big) \odot S^{c} = 0 \qquad (1)

Here, $S^{c}$ represents the complement of the matrix $S$, which essentially replaces 0s by 1s and vice versa. The $\odot$ symbol represents the Hadamard operator, which performs element-wise multiplication between matrices of the same dimensions. Including the constraint as a Lagrangian term with a penalty function and a constant $\lambda$ that acts as a tradeoff between fitting the regression and matching the graph dependency structure, we get the following optimization formulation:

\min_{\mathcal{W}} \sum_{k} \big\lVert X^{k} - f_{\mathcal{W}}(X^{k}) \big\rVert_2^2 + \lambda \Big\lVert \Big( \prod_{l} |W_l| \Big) \odot S^{c} \Big\rVert \qquad (2)

Though the bias terms are not explicitly written in the optimization to avoid clutter, we learn both the weights and the biases while optimizing Eq. 2. In our implementation, the individual weight matrices are normalized before taking the product. We normalize the regression loss and the structure loss terms separately, so that both losses are on a similar scale during training, and we recommend the range $\lambda \in [10^{-2}, 10^{2}]$. Appropriate scaling is applied to the input data features.
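A minimal sketch of the joint objective in the spirit of Eq. 2, reusing the path_dependency_matrix helper from the earlier sketch; the function name, normalization choices and hyperparameters shown here are illustrative assumptions, not the paper's exact implementation:

```python
import torch
import torch.nn as nn

def ngm_loss(mlp: nn.Sequential, X: torch.Tensor, S: torch.Tensor, lam: float) -> torch.Tensor:
    """Regression loss plus soft structure penalty (in the spirit of Eq. 2).

    X   : (num_samples, D) scaled input data
    S   : (D, D) binary dependency matrix from the input graph G
    lam : tradeoff between fitting the regression and matching the structure
    """
    recon = mlp(X)                                # multi-task regression X -> X
    reg_loss = ((recon - X) ** 2).mean()
    S_nn = path_dependency_matrix(mlp)            # product of |W_l|, see earlier sketch
    struct_loss = (S_nn * (1.0 - S)).abs().sum()  # penalize paths the graph disallows
    return reg_loss + lam * struct_loss

# Illustrative training loop (hypothetical hyperparameters):
# opt = torch.optim.Adam(mlp.parameters(), lr=1e-3)
# for _ in range(num_epochs):
#     opt.zero_grad(); loss = ngm_loss(mlp, X, S, lam=1.0); loss.backward(); opt.step()
```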

Function gradient-based(model, X):
       Split the data into the known (observed) part, kept as a fixed tensor, and the unknown (target) part, kept as a learnable tensor
       Freeze the weights of the trained NGM model
       repeat
             run the model on the current input; compute the regression loss on the observed features; update the learnable (unknown) part by backprop on this loss
       until convergence
       return the completed input (observed values plus inferred unknowns)
Function message-passing(model, X):
       Split the data into the known (observed) and unknown (target) parts
       while not converged do
             update the unknown part from the model output, keeping the observed values fixed
       return the completed input
Function NGM-inference(model, X):
       Input: trained NGM model; initialize the unknown features (e.g., with their mean values)
       run message-passing(model, X) or gradient-based(model, X)
       return the inferred values
Algorithm 2 NGMs: Inference algorithm

Proximal Initialization strategy: To get a good initialization for the NN parameters, we implement the following procedure. We solve the regression problem described in Eq. 1 without the structure constraint, which gives us a good initial guess of the NN weights. We then choose an initial value of λ and update it after each epoch. Experimentally, we found that this strategy may not work optimally in a few cases; in such cases we recommend fixing the value of λ at the beginning of the optimization. The value of λ can be chosen such that it brings the regression loss and the structure loss to the same scale.

The learned NGM describes the underlying graphical model distributions, as presented in Alg. 1. There are multiple benefits to jointly optimizing in the multi-task learning framework modeled by the neural view of NGMs, Eq. 2. First, sharing parameters across tasks significantly reduces the number of learned parameters and makes the regression more robust to noisy and anomalous data points. Second, we fully leverage the expressive power of neural networks to model complex non-linear dependencies. Additionally, learning all the functional dependencies jointly allows us to leverage GPU-accelerated batch learning for quicker runtimes.

3.4 Inference

Inference is the process of using the graphical model to answer queries. Calculation of marginal distributions and conditional distributions are key operations for inference. Since NGMs are discriminative models, for the prior distributions we follow the frequentist approach and calculate them directly from the input data. We consider two iterative procedures to answer conditional distribution queries over NGMs, described in Alg. 2. We split the input data into two parts: the known (observed) variable values and the unknown (target) variables. The inference task is to predict the values of the unknown nodes based on the trained NGM model distributions. In the first approach, we use a popular message-passing scheme that keeps the observed feature values fixed and iteratively updates the values of the unknowns until convergence. We also developed an alternative gradient-based algorithm, which is efficient and is our recommended approach for inference in NGMs.

Gradient-based approach: The weights of the NGM model are frozen once trained. The input data is divided into fixed (observed) and learnable (target) tensors. We then define a regression loss over the known attribute values, as we want the prediction to match the values of the observed features. Using this loss, we update the learnable input tensors until convergence to obtain the values of the target features. Since the NGM model is trained to match its output to its input, this procedure can be viewed as iteratively updating the unknown features so that the input and output match. Based on the convergence loss value reached after the optimization, one can assess the confidence in the inference. Furthermore, plotting the individual feature dependency functions also helps in gaining insights about predicted values.
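A minimal sketch of this gradient-based inference loop, assuming a trained PyTorch NGM and a single sample whose unknown entries are treated as a learnable tensor; names, step counts and learning rates are illustrative assumptions:

```python
import torch

def ngm_infer(mlp, x_init, obs_mask, steps=500, lr=1e-2):
    """Gradient-based conditional inference (sketch of the idea in Alg. 2).

    x_init   : (D,) tensor; observed values filled in, unknowns hold any init (e.g., means)
    obs_mask : (D,) boolean tensor, True where the feature is observed
    """
    for p in mlp.parameters():                   # freeze the trained NGM weights
        p.requires_grad_(False)
    x_obs = x_init.detach()
    x_unknown = x_init.clone().detach().requires_grad_(True)   # learnable unknowns
    opt = torch.optim.Adam([x_unknown], lr=lr)
    for _ in range(steps):
        x = torch.where(obs_mask, x_obs, x_unknown)             # keep observed values fixed
        out = mlp(x.unsqueeze(0)).squeeze(0)
        loss = ((out[obs_mask] - x_obs[obs_mask]) ** 2).mean()  # match output to observations
        opt.zero_grad()
        loss.backward()
        opt.step()
    return torch.where(obs_mask, x_obs, x_unknown.detach()), loss.item()
```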

Function get-sample(model, order):
       Initialize all features at random (learnable tensor)
       Sample the first feature's value from its empirical marginal distribution and fix it (fixed tensor)
       for each remaining feature in order do
             run NGM-inference conditioned on the features fixed so far; fix the current feature at its sampled value
       return the completed sample
Function NGM-sampling(model, G):
       Input: trained NGM model
       Randomly choose a starting feature
       order ← BFS(G, start) [undirected graphs]   or   order ← topological-sort(G) [DAGs]
       return get-sample(model, order)
Algorithm 3 NGMs: Sampling algorithm

Obtaining probability distributions. It is often desirable to get the full probability density function rather than just a point value for an inference query. In the case of categorical variables, this is readily obtained, as we output a distribution over all the categories. For real (numerical) features, we consider a binned representation on the input side and a real-valued output. In this case, the regression term of the loss function, Eq. 3, takes binned input and outputs a real value for the real-valued features. In practice, given a distribution over the different categories obtained during NGM inference, we clip the individual values to the range [0, 1] and then divide by the total sum to get the final distribution.
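A small sketch of this clip-and-renormalize step, assuming the raw per-category inference outputs are collected in a NumPy array:

```python
import numpy as np

def to_distribution(scores: np.ndarray) -> np.ndarray:
    """Clip raw per-category inference outputs to [0, 1] and renormalize to sum to 1."""
    clipped = np.clip(scores, 0.0, 1.0)
    total = clipped.sum()
    # Fall back to a uniform distribution if everything was clipped to zero.
    return clipped / total if total > 0 else np.full_like(clipped, 1.0 / len(clipped))
```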

3.5 Sampling

One common way of sampling is to define cumulative density functions and then sample from them. This will not be possible for NGMs. So, instead, we propose a procedure akin to Gibbs sampling as described in Alg. 3.

We base our sampling procedure on the conditional distribution of each feature given its neighbors; note that the neighbors (nbrs) are replaced by the Markov blanket (MB) for DAGs. We start sampling by choosing a feature at random. To get the order in which the features will be sampled, we do a breadth-first search (a topological sort for DAGs) and arrange the nodes in the resulting visiting order. In this way, the immediate neighbors are chosen first and the sampling then spreads over the graph, away from the starting feature. As we go through the ordered features in the sampling procedure, we sample the value of each feature from its conditional distribution based on the previously assigned values and then keep it fixed for the subsequent iterations (the feature is now observed). We then call the inference algorithm conditioned on these fixed features to get the distributions over the unknown features. This process is repeated until we have a sample value for all the features.

Our sampling procedure differs from Gibbs sampling in how the conditional distributions are computed. Traditionally, in Gibbs sampling, each sample is derived from the previous sample by a sequence of conditional distribution updates: the value of each feature is drawn from its distribution conditioned on the current values of all the other features. A new NGM sample is not derived from the previous sample, hence we avoid the ‘burn-in’ period issue of Gibbs sampling, where one has to discard the initial set of samples. The conditional updates for NGMs condition each feature on the features already fixed within the current sample: we keep fixing feature values and run inference on the remaining features until we have obtained the values of all the features, and thus get a new sample. The inference algorithm of the NGM facilitates conditional inference on multiple unknown features given multiple observed features, and we leverage this capability for faster sampling from NGMs.
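A sketch of one sampling pass, assuming a connected undirected networkx graph with nodes labeled 0..D-1, per-feature empirical marginal samplers, and the ngm_infer sketch from the inference section; for simplicity it fixes each feature at its inferred conditional value rather than drawing from the full conditional distribution:

```python
import random
import networkx as nx
import torch

def ngm_sample(mlp, G: nx.Graph, marginal_samplers, D: int) -> torch.Tensor:
    """Draw one sample (sketch of the idea in Alg. 3): fix features one at a time
    in BFS order and fill in the rest by conditional inference with the trained NGM."""
    start = random.randrange(D)
    order = [start] + [v for _, v in nx.bfs_edges(G, start)]   # BFS feature ordering
    x = torch.zeros(D)
    obs_mask = torch.zeros(D, dtype=torch.bool)
    x[start] = float(marginal_samplers[start]())   # seed from the empirical marginal
    obs_mask[start] = True
    for i in order[1:]:
        # Conditional inference given everything fixed so far (ngm_infer: earlier sketch),
        # then fix feature i at its inferred value and treat it as observed.
        x, _ = ngm_infer(mlp, x, obs_mask)
        obs_mask[i] = True
    return x
```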

3.6 Extension to generic data types

The learning, inference and sampling algorithms proposed for NGMs in the previous sections can be extended to generic input data types. This implies that the data X can be real-valued, categorical, image-based, or given as embedding representations. We add a Projection module, consisting of an encoder and a decoder, that acts as a wrapper around the neural view of the NGMs. With a slight modification, we obtain the following optimization for generic data types:

\min_{\mathcal{W}, E, D} \sum_{k} \big\lVert X^{k} - D\big( f_{\mathcal{W}}( E(X^{k}) ) \big) \big\rVert_2^2 + \lambda \Big\lVert \Big( \prod_{l} |W_l| \Big) \odot S^{c} \Big\rVert \qquad (3)

where E and D denote the encoder and decoder of the Projection module.

The Projection module can be jointly learned in the optimization, as shown in Eq. 3, or one can add fine-tuning layers to the pretrained versions depending on the data type and user preference.
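A minimal sketch of such a wrapped architecture in PyTorch; the linear encoder/decoder and the dimension names are illustrative assumptions, since the actual projection modules depend on the data type and the user's design choices:

```python
import torch.nn as nn

class NGMGeneric(nn.Module):
    """Neural view wrapped with a Projection module (encoder + decoder), cf. Fig. 3."""

    def __init__(self, d_in: int, d_graph: int, d_hidden: int):
        super().__init__()
        # Encoder: maps raw inputs (one-hot vectors, embeddings, ...) to graph features.
        self.encoder = nn.Linear(d_in, d_graph)
        # Neural view: MLP whose path sparsity is matched to the dependency structure S.
        self.neural_view = nn.Sequential(
            nn.Linear(d_graph, d_hidden), nn.ReLU(), nn.Linear(d_hidden, d_graph))
        # Decoder: maps the neural-view output back to the raw input space.
        self.decoder = nn.Linear(d_graph, d_in)

    def forward(self, x):
        return self.decoder(self.neural_view(self.encoder(x)))
```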

Alternatively, one can extend the idea of soft-thresholding the connection patterns to the encoder and decoder networks. This will result in an efficient training strategy that leverages batch processing.

(4)

where the connectivity from the raw input to the neural view input is modeled by a sparsity term over the encoder network's connectivity pattern. A similar procedure is applied on the decoder side.

Figure 3: Neural view with Projection modules of NGMs: The input X can be one-hot (categorical), an image or, in general, an embedding (text, audio, speech and other data types). Projection modules (encoder + decoder) are used as a wrapper around the neural view of NGMs. The architecture choice of the projection modules depends on the input data type and the user's design choices. Note that the output of the encoder can be more than 1 unit (a feature can map to a hypernode); in that case, we just need to adjust the graph dependency structure S to account for that many units and the corresponding feature connections. The same holds for the decoder side of the architecture. The remaining details are similar to the ones described in Fig. 2.

If the Projection modules are used, the number of nodes in the neural view input should be adjusted according to the output units of the encoder; a similar adjustment is needed between the neural view output and the decoder. In real-world applications, we often find inputs consisting of mixed datatypes. For instance, in gene expression data there can be additional meta information (categorical) or images associated with the genes. Optionally, one may wish to utilize node embeddings from other pretrained deep learning models. NGMs are designed to handle such mixed input data types simultaneously, which are otherwise very tricky to accommodate in existing graphical models.

4 Experiments

We evaluate NGMs on synthetic and real data. Appendix A contains some best practices that we developed while working with NGMs. In Appendix B, we present an analysis of the CDC's Infant Mortality data (CDC:InfantLinkedDatasets) using NGMs, which highlights the NGM-generic architecture's ability to model mixed input datatypes.

4.1 Modeling Gaussian Graphical models

We designed a synthetic experiment to study the capability of NGMs to represent Gaussian graphical models (GGMs). The aim of this experiment is to see (via plots and sampling) how close the distributions learned by the NGMs are to those of the GGMs.

Figure 4: The leftmost graph shows the chain graph G (partial correlations in green are positive, red are negative; thickness shows the correlation strength) obtained from the initialized partial correlation matrix. Samples were drawn from the GGM, and an NGM was learned on the input (X, G). The 2 plots on the right show the dependency functions of the NGM and the GGM for a particular node obtained by varying its neighbor's values. The positive and negative correlations are reflected in the slope of the curve, as expected analytically. We then sampled from the learned NGM to obtain data Xs. The graph second from the left shows the graph recovered by running uGLAD (shrivastava2022uglad) on Xs. We can observe that it missed some of the edges, but most of the connections, along with the correlation signs, were retrieved from the NGM samples.

Table 1: The recovered CI graph from NGM samples (1000, 2000 and 4000 samples) is compared with the CI graph defined by the GGM precision matrix. Area under the ROC curve (AUC) and area under the precision-recall curve (AUPR) values over 10 runs are reported; refer to Fig. 4.

Setup: Define the underlying graph. We defined a ‘chain’ (or path-graph) containing D nodes as the underlying graph. We chose this graph as it allows for an easier study of dependency functions.

Fit GGM and get samples. Based on the underlying graph structure, we defined a precision matrix $\Theta$ with randomly sampled entries for the chain edges. We then used this precision matrix as the parameter of a multivariate Gaussian distribution to obtain the input sample data X. We get the corresponding partial correlation graph G using the formula $\rho_{ij} = -\Theta_{ij}/\sqrt{\Theta_{ii}\Theta_{jj}}$.
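A sketch of this setup with NumPy; the chain-edge weight range, sample count and random seed below are illustrative assumptions, not the values used in the paper:

```python
import numpy as np

rng = np.random.default_rng(0)
D = 10

# Chain-graph precision matrix: node i is connected only to node i+1.
theta = np.eye(D)
for i in range(D - 1):
    w = rng.uniform(0.2, 0.5) * rng.choice([-1.0, 1.0])
    theta[i, i + 1] = theta[i + 1, i] = w
# Nudge the diagonal if needed so the matrix stays positive definite.
min_eig = np.linalg.eigvalsh(theta).min()
if min_eig < 1e-3:
    theta += np.eye(D) * (1e-3 - min_eig)

# Draw samples X ~ N(0, theta^{-1}) from the GGM.
X = rng.multivariate_normal(np.zeros(D), np.linalg.inv(theta), size=2000)

# Partial correlations: rho_ij = -theta_ij / sqrt(theta_ii * theta_jj).
d = np.sqrt(np.diag(theta))
rho = -theta / np.outer(d, d)
np.fill_diagonal(rho, 1.0)
```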

Fit NGM and get samples. We fit an NGM on the input (X, G), using a 2-layer MLP for the neural view. Training was done by optimizing Eq. 2 on the input; refer to Fig. 4. We then obtained data samples Xs from the learned NGM.

Analysis: ‘How close are the GGM and NGM samples?’ We recover the graph using the graph recovery algorithm uGLAD on the sampled data points from NGMs and compare it with the true CI graph. Table 1 shows the graph recovery results of varying the number of samples from NGMs. We observe that increasing the number of samples improves the graph recovery, which is expected.

‘Were the NGMs able to model the underlying distributions?’ The function plots (on the right) in Fig. 4 show the resulting regression function for a particular node as learned by the NGM. This straight line, with slope corresponding to the partial correlation value, is what we expect theoretically for the GGM chain graph. This is also an indication that the learned NGMs were trained properly and reflect the desired underlying relations. Thus, NGMs are able to represent GGMs.

Figure 5: (left) The CI graph recovered by uGLAD for the lung cancer data. The plots on the right show the conditional distributions P(Lung cancer='Yes' | nbrs(Lung cancer)) and P(Smoking | nbrs(Smoking)) based on their neighbors. We used a 2-layer NGM for the neural view. NGMs are able to capture the non-linear dependencies between the features. Interestingly, the NGMs match the relationship trends (positive and negative correlations) discovered by the corresponding CI graph.

4.2 Lung cancer data analysis

We analysed lung cancer data (lcData) using NGMs. An effective cancer prediction system helps people learn their cancer risk at low cost and supports appropriate decisions based on that risk status. The data contain 284 patient records, and for each patient 16 features (Gender, Smoking, Anxiety, Lung cancer present, etc.) are collected. Each entry is binary (YES/NO); in some cases (e.g., AGE), entries were binarized. In particular, we used NGMs to study how different features are related and to discover their underlying functional dependencies.

Table 2: 5-fold CV results comparing logistic regression (LR) and NGM on the 'Lung-cancer' and 'Smoking' prediction tasks.

The input data, along with the CI graph recovered using uGLAD, were used to learn an NGM, Fig. 5. In order to gauge the regression quality of NGMs, we compare against logistic regression for predicting the probability of a feature's values given the values of the remaining features. Table 2 shows regression results of logistic regression (LR) and NGMs on 2 different features, 'lung cancer' and 'smoking'. The prediction probabilities for NGMs were calculated by running inference on each test datapoint, e.g., P(lung-cancer='yes' | remaining features of the test datapoint). This experiment primarily demonstrates that a single NGM model can robustly fit multiple regressions, so one can avoid training a separate regression model for each feature while maintaining at-par performance. Furthermore, we can obtain the dependency functions, which bring more interpretability to the predicted results, Fig. 5. Samples generated from this NGM model can be used for multiple downstream analyses.

5 Conclusions

This work attempts to improve the usefulness of probabilistic graphical models by extending the range of input data types and distribution forms such models can handle. Neural Graphical Models provide a compact representation for a wide range of complex distributions and support efficient learning, inference and sampling. The experiments are carefully designed to systematically explore the various capabilities of NGMs. Though NGMs can leverage GPUs and distributed computing hardware, we do foresee some challenges in terms of scaling with the number of features and performance on very high-dimensional data. Using NGMs for image- and text-based applications will be interesting to explore. We believe that NGMs are an interesting amalgam of the expressivity of deep learning architectures and the representation capabilities of probabilistic graphical models.

Upcoming version: Discovering the dependency graph with NGMs. We are currently working on a version of NGM that can jointly discover the feature dependency graph while fitting the regression. One way is to optimize the following loss function:

(5)

where the structure matrix has its diagonal entries fixed. Essentially, we start with a fully connected graph and the regularization term induces sparsity. This will be helpful in cases where an input G is not provided.

References

Appendix A Design strategies and best practices for NGMs

We share some of the design strategies and best practices that we developed while working with NGMs here. This is to give insights to the readers on our approach and help them narrow down the architecture choices of NGMs for applying to their data. We hope that sharing our thought process and findings here will foster more transparency, adoption and help identify potential improvements to facilitate the advancement of research in this direction.


  • Choices for the structure loss function. We narrowed down the loss function choice to the Hadamard (element-wise) loss versus a squared loss, and also experimented with various choices of Lagrangian penalties for the structure loss. We found that the Hadamard loss worked better in most cases. Our conclusion was to use the Hadamard loss with either of the two penalties.

  • Strategies for λ initialization. (I) Keep λ fixed to balance the initial regression loss and structure loss; we utilize the loss balance technique mentioned in rajbhandari2019antman. (II) Use the proximal initialization technique coupled with an increasing λ value, as described in Alg. 1. Both techniques seem to work well, although (I) is simpler to implement and gives equivalent results.

  • Selecting width and depth of the neural view. We start with hidden layer size twice the input dimension. Then based on the regression and structure loss values, we decide whether to go deeper or have a larger number of units. In our experience, increasing the number of layers helps in reducing the regression loss while increasing the hidden layer dimensions works well to optimize for the structure loss.

  • Choices of non-linearity. For the MLP in the neural view, we experimented with multiple choices of non-linearities. We ended up using ReLU, although other choices gave similar results.

  • Handling imbalanced data. NGMs can also be adapted to utilize the existing imbalanced data handling techniques chawla2002smote; shrivastava2015classification; bhattacharya2017icu which improved results in our experience.

  • Calculate an upper bound on the regression loss. Try fitting the NGM assuming a fully connected graph, giving the regression maximum flexibility. This way we get an upper bound on the best achievable regression loss, which helps in selecting the depth and dimensions of the MLPs required once the sparser structure is imposed.

  • Convergence of the loss function. In our quest to reliably get good convergence on both losses (regression & structure), we tried various approaches. (I) Jointly optimize both loss functions with a weight-balancing term λ, Eq. 2. (II) We tested an Alternating Direction Method of Multipliers (ADMM) based optimization that alternately optimizes the structure loss and the regression loss. (III) We also ran a proximal gradient descent approach, which is sometimes suitable for losses with regularization terms. Choice (I) turned out to be effective with reasonable λ values.

In its current state, it can be tedious to optimize NGMs and a decent amount of experimentation is needed. It is a learning experience for us as well, and we are always on the lookout for new techniques from the research community.

Appendix B Infant Mortality analysis

We created an NGM to model infant mortality data. The dataset is based on the CDC Birth Cohort Linked Birth – Infant Death Data Files (CDC:InfantLinkedDatasets). It describes pregnancy and birth variables for all live births in the U.S., together with an indication of an infant's death before the first birthday. We used the data for 2015 (the latest available), which includes information about 3,988,733 live births in the US during the 2015 calendar year.

We recovered the graph structure of the dataset using uGLAD (shrivastava2022uglad) and using the Bayesian network package bnlearn (bnlearn) with Tabu search and the AIC score. The graphs are shown in Fig. 7 and 6, respectively. Since bnlearn does not support networks containing both continuous and discrete variables, all variables were converted to categorical for bnlearn structure learning and inference. In contrast, uGLAD and NGMs are both equipped to work with mixed types of variables and were trained on the dataset prior to conversion.

Both graphs show similar sets of clusters with high connectivity within each cluster:


  • describing both parents’ race and ethnicity (mrace and frace variables),

  • related to mother’s bmi, height (mhtr) and weight, both pre-pregnancy (pwgt_r) and at delivery (dwgt_r),

  • consisting of maternal morbidity variables marked with mm prefix (e.g., unplanned hysterectomy),

  • showing pregnancy related complications such as hypertension and diabetes (variables prefixed with rf and urf),

  • consisting of variables related to parents’ STD infections (ip prefix),

  • related to delivery complications and interventions (variables prefixed with ld),

  • showing interventions after delivery (ab prefix) such as ventilation or neonatal ICU,

  • describing congenital anomalies diagnosed in the infant at the time of birth (variables prefixed with ca),

  • related to infant’s death: age at death, place, autopsy, manner, etc.

Figure 6: The Bayesian network graph learned using score-based method for the Infant Mortality 2015 data.
Figure 7: The CI graph recovered by uGLAD for the Infant Mortality 2015 data.
Figure 8: Comparing the graphs recovered by uGLAD and Bayesian Network recovery package (bnlearn) after moralization (moralized edges are denoted by ‘skyblue’).

Apart from these clusters, there are a few highly connected variables in both graphs: gestational age (combgest and oegest), delivery route (rdmeth_rec), Apgar score, type of insurance (pay), parents’ ages (fage and mage variables), birth order (tbo and lbo), and prenatal care.

With all these similarities, however, the total number of edges varies greatly between the two graphs, and the number of edges unique to each graph outnumbers the edges the two graphs have in common (see Figure 8). One reason for the differences lies in the continuous-to-categorical conversion performed prior to Bayesian network structure discovery and training. The two graph recovery algorithms are also very different in both algorithmic approach and objective function. We plan to further explore NGMs' sensitivity to the input graph recovery algorithm in future work.

The infant mortality dataset is particularly challenging, since cases of infant death during the first year of life are (thankfully) rare. Thus, any queries concerning such low-probability events are hard to estimate with accuracy.

NGM-generic architecture: Since we have mixed input data types (real and categorical), we utilize the NGM-generic architecture shown in Fig. 3. We consider a 2-layer neural view. The categorical input was converted to its one-hot vector representation and combined with the real features to form the input. The neural view input from the encoder had the same dimension as this input; similarly, we maintained the same dimension from the neural view output to the decoder output. The entire set of NGM-generic parameters was learned by minimizing Eq. 4 using the 'adam' optimizer.

Sensitivity to the input graph: To study the effect of different graph structures on NGMs, we train separate models on the Bayesian network graph (after moralizing) and the CI graph from uGLAD, given in Fig. 6 & 7 respectively. We plot the dependency functions between pairs of nodes based on the common and unique edges found in the comparison plots of Fig. 8. For each pair of features, the dependency function is obtained by running inference while varying the value of one feature over its range, as shown in Fig. 9.

Figure 9: Evaluating the effects of varying input graphs for learning NGMs. Comparing the NGM dependency plots recovered using the Bayesian network graph vs the CI graph obtained by running uGLAD. Similar NGM architectures were chosen and the data preprocessing was kept as alike as possible. For the feature pairs in the top box, the trends match for both graphs, while in the bottom box the dependency plots differ. We observed that the dependency trends discovered by the NGM trained on the CI graph match the correlations of the CI graph. Common edges present in both graphs: [(pwgt-r, dwgt-r), (wtgain, mhtr), (bmi, mhtr), (precare, pay)]; edges only present in the CI graph: [(wtgain, dmar), (wtgain, bmi)]. It is interesting to observe that even for some common edges, e.g., (wtgain, mhtr), which represent strong direct dependence between the features, the trends can still differ significantly. This highlights the importance of the input graph structure chosen to train NGMs.

Comparing NGM inference in models with different input graphs shows some interesting patterns:


  • Strong positive correlation of mother’s delivery weight (dwgt_r) with pre-pregnancy weight (pwgt_r) is shown in both models.

  • Similarly, both models show that married mothers (dmar) are likely to gain more weight than unmarried mothers.

  • Both models agree that women with high BMI tend to gain less weight during their pregnancies than women with low BMI.

  • A discrepancy appears in the dependence of both BMI and weight gain during pregnancy on mother's height (mhtr). According to the NGM trained with the BN graph, higher weight gain and higher BMI are more likely for tall women, while the CI-trained NGM shows the opposite.

  • Possibly the most interesting are the graphs showing the dependence of the time at which a woman starts prenatal care (precare specifies the month of pregnancy when prenatal care starts) on the type of insurance she carries. For both models, Medicaid (1) and private insurance (2) mean an early start of care, and there is a sharp increase (delay in prenatal care start) for self-pay (3) and Indian Health Service (4). The models disagree to some extent on less common types of insurance (military, government, other, unknown).

Our experiments on the infant mortality dataset demonstrate the usefulness of NGMs for modeling complex mixed-input real-world domains. We are currently running more experiments designed to capture more information on NGMs' sensitivity to the input graph recovery algorithm and inference accuracy.