Robust MIMO Detection using Hypernetworks with Learned Regularizers

by Nicolas Zilberstein, et al.
Rice University

Optimal symbol detection in multiple-input multiple-output (MIMO) systems is known to be an NP-hard problem. Recently, there has been growing interest in getting reasonably close to the optimal solution using neural networks while keeping the computational complexity in check. However, existing work based on deep learning shows that it is difficult to design a generic network that works well for a variety of channels. In this work, we propose a method that aims to strike a balance between symbol error rate (SER) performance and generality across channels. Our method is based on hypernetworks that generate the parameters of a neural network-based detector that works well on a specific channel. We propose a general framework that regularizes the training of the hypernetwork with some pre-trained instances of the channel-specific method. Through numerical experiments, we show that our proposed method yields high performance for a set of prespecified channel realizations while generalizing well to all channels drawn from a specific distribution.




1 Introduction

Multiple-input multiple-output (MIMO) systems are an essential part of modern communications [mimoreview1], [mimoreview2]. Moreover, they are expected to play a fundamental role in moving from the fifth to the sixth generation of cellular communications by achieving high data rates and spectral efficiency [6g]. In MIMO systems, base stations are equipped with multiple antennas, enabling them to handle several users simultaneously. However, these systems entail many challenges such as performing efficient symbol detection, which is the focus of our paper.

Exact MIMO detection is an NP-hard problem [Pia2017MixedintegerQP]. Given $N_u$ users and a modulation of $M$ symbols, the exact maximum likelihood (ML) estimator has an exponential complexity $\mathcal{O}(M^{N_u})$. Thus, obtaining this ML estimate becomes computationally intractable even for moderately sized systems. Many approximate solutions for symbol detection have been proposed in the classical literature, including zero forcing (ZF) and minimum mean squared error (MMSE) [Proakis2007]. Although both (linear) detectors have low complexity and good performance for small systems, their performance degrades severely for larger systems [chockalingam_rajan_2014]. Another classical detector is approximate message passing (AMP), which is asymptotically optimal for large MIMO systems with Gaussian channels but degrades significantly for other (more practical) channel distributions [amp].
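To make the two linear baselines concrete, the following is a minimal NumPy sketch of ZF and MMSE detection, each followed by slicing to the nearest constellation point. It assumes a 4-QAM constellation normalized to unit average power and the textbook MMSE filter $(\mathbf{H}^H\mathbf{H} + (\sigma^2/E_s)\mathbf{I})^{-1}\mathbf{H}^H$; the function names are illustrative and are not taken from the authors' code.

```python
import numpy as np

# 4-QAM constellation normalized to unit average power (illustrative choice).
QAM4 = np.array([1 + 1j, 1 - 1j, -1 + 1j, -1 - 1j]) / np.sqrt(2)

def hard_decision(z, constellation=QAM4):
    """Map each entry of z to its nearest constellation point."""
    dist = np.abs(z[:, None] - constellation[None, :])
    return constellation[np.argmin(dist, axis=1)]

def zf_detect(y, H):
    """Zero forcing: least-squares inversion of H, then slicing."""
    z = np.linalg.pinv(H) @ y
    return hard_decision(z)

def mmse_detect(y, H, sigma2, Es=1.0):
    """Linear MMSE: (H^H H + (sigma2/Es) I)^{-1} H^H y, then slicing."""
    Nu = H.shape[1]
    Hh = H.conj().T
    G = np.linalg.solve(Hh @ H + (sigma2 / Es) * np.eye(Nu), Hh)
    return hard_decision(G @ y)
```

Both detectors only require solving a linear system whose size equals the number of users, which is why their cost grows polynomially rather than exponentially.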

Recently, machine learning and, in particular, deep learning have been proposed to solve fundamental problems in wireless communications such as power allocation [eisen2020regnn, UWMMSE, chowdhury2021ml], link scheduling [linkscheduling, zhao2021link], and random access control [kumar2021adaptive]. For the particular case of MIMO symbol detection, several solutions have been proposed [mmnet, detNet2017, hypermimo, remimo]. At a high level, one can categorize the existing methods into two classes: 1) channel-specific methods learn to perform symbol detection for a prespecified channel realization, and 2) channel-agnostic methods can perform symbol detection for a wide variety of channels, typically drawn from a distribution of interest. MMNet [mmnet] and the so-called fixed channel version of DetNet [detNet2017] are examples of channel-specific methods, whereas HyperMIMO [hypermimo], RE-MIMO [remimo], and the varying channel version of DetNet [detNet2017] are examples of channel-agnostic methods. Naturally, the first class attains very high performance on the channels for which it was trained but fails on other channels from the same distribution. Moreover, as these methods have to be trained for each channel realization, they are typically unsuitable for real-time applications. In contrast, the second class generalizes well across a distribution without the need for retraining but cannot match the performance of the first class on the specific channels for which the first class was trained.

Our goal is to combine these two classes of methods to attain a solution that yields very high performance for a set of prespecified channels (and their perturbations) and generalizes well to all channels coming from a distribution of interest without the need to retrain. We achieve this by first constructing a channel-agnostic method based on a channel-specific one using the concept of hypernetworks [ha2016hypernetworks], and then regularizing the training of the hypernetwork with several pre-trained instances of the channel-specific method. Although the proposed framework is generic in terms of which channel-specific detector to choose, we focus on MMNet [mmnet] and its corresponding hypernetwork extension, HyperMIMO [hypermimo].

Contribution. The contributions of this paper are twofold:
1) We propose a learning-based solution for MIMO detection that yields high accuracy for perturbations of a prespecified set of channels while generalizing to a whole distribution. We attain this via a HyperMIMO architecture whose training is regularized by solutions of the MMNet.
2) Through numerical experiments, we demonstrate that the proposed solution achieves symbol error rates below those obtained by HyperMIMO and MMNet trained separately while maintaining the (forward-pass) computational complexity of HyperMIMO.

2 System model and problem formulation

We consider a communication channel with $N_u$ single-antenna transmitters or users and a receiving base station with $N_r$ antennas. The forward model for this MIMO system is given by

$$\mathbf{y} = \mathbf{H}\mathbf{x} + \mathbf{n}, \tag{1}$$

where $\mathbf{H} \in \mathbb{C}^{N_r \times N_u}$ is the channel matrix, $\mathbf{n} \sim \mathcal{CN}(\mathbf{0}, \sigma^2 \mathbf{I}_{N_r})$ is a vector of complex circular Gaussian noise, $\mathbf{x} \in \mathcal{X}^{N_u}$ is the vector of transmitted symbols, $\mathcal{X}$ is a finite set of constellation points, and $\mathbf{y} \in \mathbb{C}^{N_r}$ is the received vector. In this work, quadrature amplitude modulation (QAM) is used and each symbol is normalized to unit average power. It is assumed that the constellation is the same for all transmitters and that each symbol is equally likely to be chosen by the users, i.e., each entry of $\mathbf{x}$ is uniformly distributed over $\mathcal{X}$. Moreover, perfect channel state information (CSI) is assumed, which means that $\mathbf{H}$ and $\sigma^2$ are known at the receiver. (To avoid notation overload, we adopt the convention that whenever we assume $\mathbf{H}$ to be known, $\sigma^2$ is also known.) Under this setting, the MIMO detection problem can be defined as follows.
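The forward model (1) and the assumptions above (equiprobable symbols with unit average power, circularly symmetric Gaussian noise) can be simulated in a few lines. This is an illustrative sketch; `simulate_mimo` and the use of 4-QAM as the default constellation are our assumptions.

```python
import numpy as np

# 4-QAM with unit average power: |c|^2 = 1 for every point.
QAM4 = np.array([1 + 1j, 1 - 1j, -1 + 1j, -1 - 1j]) / np.sqrt(2)

def simulate_mimo(H, sigma2, rng, constellation=QAM4):
    """Draw x uniformly from the constellation and return (x, y) with y = Hx + n."""
    Nr, Nu = H.shape
    # Each user picks a symbol uniformly at random (equiprobable symbols).
    x = constellation[rng.integers(0, len(constellation), size=Nu)]
    # Circularly symmetric complex Gaussian noise with E|n_i|^2 = sigma2.
    n = np.sqrt(sigma2 / 2) * (rng.standard_normal(Nr) + 1j * rng.standard_normal(Nr))
    return x, H @ x + n
```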

Problem 1

Given perfect CSI and an observed $\mathbf{y}$ following (1), find an estimate $\hat{\mathbf{x}}$ of $\mathbf{x}$.

Given the stochastic nature of $\mathbf{n}$ in (1), a natural way of solving Problem 1 is to search for the $\mathbf{x}$ that maximizes the probability of observing our given $\mathbf{y}$. Unfortunately, such an ML detector boils down to solving the optimization problem

$$\hat{\mathbf{x}}_{\mathrm{ML}} = \arg\min_{\mathbf{x} \in \mathcal{X}^{N_u}} \|\mathbf{y} - \mathbf{H}\mathbf{x}\|_2^2, \tag{2}$$

which is NP-hard due to the finite constellation constraint $\mathbf{x} \in \mathcal{X}^{N_u}$ [Pia2017MixedintegerQP], rendering it intractable in practical applications. Consequently, several schemes have been proposed in the last decades to provide efficient approximate solutions to Problem 1, as mentioned in Section 1.
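The exponential cost of (2) is easy to see in a brute-force sketch that enumerates every candidate in $\mathcal{X}^{N_u}$. This is only meant for very small systems; `ml_detect` is a hypothetical helper, and the experiments in this paper use the Gurobi solver for ML rather than exhaustive search.

```python
import itertools
import numpy as np

QAM4 = np.array([1 + 1j, 1 - 1j, -1 + 1j, -1 - 1j]) / np.sqrt(2)

def ml_detect(y, H, constellation=QAM4):
    """Exhaustive ML search: argmin over X^{Nu} of ||y - Hx||^2."""
    Nu = H.shape[1]
    best_x, best_cost = None, np.inf
    # |X|^Nu candidates: this loop is what makes exact ML intractable.
    for cand in itertools.product(constellation, repeat=Nu):
        x = np.array(cand)
        cost = np.linalg.norm(y - H @ x) ** 2
        if cost < best_cost:
            best_x, best_cost = x, cost
    return best_x
```

For example, with 4-QAM and 16 users the loop would visit $4^{16} = 4\,294\,967\,296$ candidates per received vector.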

The classical body of work (ZF, MMSE, AMP) focuses on solving a single instance of Problem 1 for arbitrary $\mathbf{y}$ and $\mathbf{H}$, which must then be repeated to recompute the detection in successive communication instances. Given that in practice we are interested in solving several instances of Problem 1 across time, a learning-based body of work has gained traction in recent years [mmnet, hypermimo, detNet2017, remimo]. In a nutshell, based on many tuples $(\mathbf{y}, \mathbf{H}, \mathbf{x})$, the idea is to learn a map (a function approximator) from the space of observations and CSI to the corresponding (approximate) transmitted symbols. In this way, when a new observation $\mathbf{y}$ is received (along with the CSI), $\mathbf{x}$ can be efficiently estimated using the learned map without the need to solve an optimization problem.

Having introduced this framework, we can provide a precise distinction between the families of learning-based methods that we call channel-specific and channel-agnostic. Channel-specific methods like DetNet [detNet2017] (for the fixed channel case) and MMNet [mmnet] learn a different function for every $\mathbf{H}$, i.e., they learn a function $f_{\mathbf{H}}$ such that $\hat{\mathbf{x}} = f_{\mathbf{H}}(\mathbf{y})$ is a good solution to Problem 1 for a specific $\mathbf{H}$ of interest. On the other hand, channel-agnostic methods like HyperMIMO [hypermimo] and RE-MIMO [remimo] consider the CSI as input to their learnable functions, i.e., they look for $f$ such that $\hat{\mathbf{x}} = f(\mathbf{y}, \mathbf{H})$ is a good solution to Problem 1. Naturally, such a satisfactory $f$ cannot be found for completely arbitrary $\mathbf{H}$; rather, channel-agnostic methods focus on channels drawn from some distribution of interest. Moreover, due to the specialized nature of $f_{\mathbf{H}}$, channel-specific models tend to perform better for the particular channel $\mathbf{H}$, but their performance quickly degrades when a different channel is drawn.

In this setting, we are motivated by the following question: Can we develop a generalizable channel-agnostic method that achieves performance comparable with channel-specific methods for a channel $\mathbf{H}_0$ (or set of channels $\mathcal{H}_0$) of interest? In essence, we want to keep the best of both classes of methods by performing close to optimally on prespecified channels while generalizing to a whole distribution. Our motivating question is relevant in practice when, e.g., the channel fading varies smoothly with time as in the Jakes model [Nasir2019] (see Section 4 for more details). In such a case, we want our learning scheme to perform especially well around the current channel while generalizing satisfactorily to avoid the need for immediate retraining.

Figure 1: Our proposed solution seeks to combine the best of both classes of methods by specializing to a channel (or set of channels) of interest while generalizing to a whole channel distribution.

At a high level, given some metric in the space of channels, channel-specific solutions yield a lower symbol error rate (SER) close to the channel $\mathbf{H}_0$ for which they were trained, whereas channel-agnostic methods work better when channels farther away from $\mathbf{H}_0$ are drawn; see Fig. 1. Intuitively, we want to derive a method that attains the behavior illustrated in green in Fig. 1. Hence, one can think of our sought solution as a robust version of a channel-specific method that gracefully degrades into a channel-agnostic method. Alternatively, one can see the envisioned solution as a channel-agnostic method that has been specially tuned to excel on a subset of channels of interest. Either way, we propose to achieve this through the use of hypernetworks whose training is regularized by the solutions of channel-specific methods, as we detail next.

3 Hypernetworks with Learned Regularizers

In Section 3.1 we introduce the notion of a hypernetwork and its use in machine learning whereas in Section 3.2 we detail how we incorporate hypernetworks in our solution to Problem 1.

3.1 Hypernetworks in machine learning and MIMO detection

Hypernetworks are neural networks whose output is the set of weights of another target (or main) neural network, which performs the learning task [ha2016hypernetworks]; see Fig. 2. More precisely, our main network is a classical neural network that learns a parametric function $g_{\boldsymbol{\theta}}$ from input data $\mathbf{u}$ to some desired target $\mathbf{v}$, with $\boldsymbol{\theta}$ being the learnable parameters of this neural network. The goal of the hypernetwork, on the other hand, is to learn a parametric function $h_{\boldsymbol{\phi}}$ from a (possibly different) input $\mathbf{z}$ into the space of parameters of the main network. In this way, we do not fix the parameters of our main network but rather make them a function $\boldsymbol{\theta} = h_{\boldsymbol{\phi}}(\mathbf{z})$ of the input $\mathbf{z}$, which can improve the generalizability of $g$. Thus, given inputs $\mathbf{u}$ and $\mathbf{z}$, the output of the main network is given by $\hat{\mathbf{v}} = g_{h_{\boldsymbol{\phi}}(\mathbf{z})}(\mathbf{u})$. Note that only the parameters $\boldsymbol{\phi}$ of the hypernetwork need to be learned during training.
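The composition $g_{h_{\boldsymbol{\phi}}(\mathbf{z})}(\mathbf{u})$ can be illustrated with a deliberately tiny sketch: a linear hypernetwork produces the weights of a single linear main layer. The names `hypernet` and `main_net`, the linear forms, and the sizes are assumptions chosen for readability; only `phi` would be trained.

```python
import numpy as np

def hypernet(z, phi):
    """h_phi: linear map from the hypernetwork input z to main-network weights theta."""
    W, b = phi
    return W @ z + b

def main_net(u, theta, n_in, n_out):
    """g_theta: a single linear layer whose weights are supplied, not stored."""
    A = theta[: n_out * n_in].reshape(n_out, n_in)
    c = theta[n_out * n_in:]
    return A @ u + c

rng = np.random.default_rng(0)
n_in, n_out, n_z = 3, 2, 4
n_theta = n_out * n_in + n_out  # number of main-network parameters
# Only phi is trainable; theta is recomputed from z on every forward pass.
phi = (rng.standard_normal((n_theta, n_z)), rng.standard_normal(n_theta))

z = rng.standard_normal(n_z)   # hypernetwork input
u = rng.standard_normal(n_in)  # main-network input
v_hat = main_net(u, hypernet(z, phi), n_in, n_out)  # g_{h_phi(z)}(u)
```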

Figure 2: The general scheme of a hypernetwork. The hypernetwork takes the input $\mathbf{z}$ and generates the weights $\boldsymbol{\theta}$, which are fed into the main network. Then, the main network takes $\mathbf{u}$ as input and returns the output $\hat{\mathbf{v}}$.

The notion of a hypernetwork has been used in several contexts such as object recognition [Bertinetto2016LearningFO] and generation of 3-D point clouds [spurek2020hypernetwork]. Hypernetworks have also been used for 3-D shape reconstruction [littwin2019deep] and to learn shared representations of images [sitzmann2020implicit]. In the specific context of MIMO detection, hypernetworks were already proposed in [hypermimo], applied to MMNet. As MMNet is tied to a particular channel realization, the hypernetwork enables generalization to a whole distribution of channels.

3.2 Learning hypernetwork regularizers

Having formally introduced the concept of a hypernetwork, we can now revisit HyperMIMO [hypermimo], which follows exactly the framework in Fig. 2. In particular, the main network, which takes the form of an MMNet [mmnet], receives as input the observation and the CSI, i.e., the pair $(\mathbf{y}, \mathbf{H})$. Moreover, the hypernetwork takes the CSI $\mathbf{H}$ as input and generates the weights $\boldsymbol{\theta} = h_{\boldsymbol{\phi}}(\mathbf{H})$ for the multiple layers of the main MMNet. Then, the MMNet generates the estimate $\hat{\mathbf{x}}$ of the transmitted symbols. The weights $\boldsymbol{\phi}$ of the hypernetwork are trained to minimize a loss that compares $\hat{\mathbf{x}}$ with the true transmitted symbols $\mathbf{x}$. This flow is depicted by blue arrows in Fig. 3.
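To give a feel for what the generated weights control, here is a simplified sketch of one MMNet-style detection layer: a linear residual step $\mathbf{z} = \hat{\mathbf{x}} + \mathbf{A}(\mathbf{y} - \mathbf{H}\hat{\mathbf{x}})$ followed by a Gaussian posterior-mean denoiser over the constellation. This is a simplification under stated assumptions: the noise variance `sigma2` is passed in directly (MMNet computes a per-layer estimate instead), the matrix `A` stands for the kind of parameter a hypernetwork would output, and the helper names are hypothetical.

```python
import numpy as np

QAM4 = np.array([1 + 1j, 1 - 1j, -1 + 1j, -1 - 1j]) / np.sqrt(2)

def gaussian_denoiser(z, sigma2, constellation=QAM4):
    """Posterior mean E[x | z] for z = x + CN(0, sigma2), x uniform over the constellation."""
    logits = -np.abs(z[:, None] - constellation[None, :]) ** 2 / sigma2
    logits -= logits.max(axis=1, keepdims=True)  # numerical stability
    w = np.exp(logits)
    w /= w.sum(axis=1, keepdims=True)
    return w @ constellation

def mmnet_style_layer(x_hat, y, H, A, sigma2):
    """One simplified layer: linear residual correction, then nonlinear denoising.
    In HyperMIMO, layer parameters such as A would come from the hypernetwork."""
    z = x_hat + A @ (y - H @ x_hat)
    return gaussian_denoiser(z, sigma2)
```

Stacking several such layers, each with its own generated parameters, gives the overall structure that the hypernetwork parameterizes.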

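We expand the described training procedure to attain a solution to Problem 1 that captures the desirable behavior in Fig. 1. First, we determine the channel or set of channels $\mathcal{H}_0 = \{\mathbf{H}^{(1)}, \dots, \mathbf{H}^{(K)}\}$ on which we want our solution to achieve especially high detection performance. This choice will be guided by the nature of the system in which we anticipate that our detector will be deployed. Given many realizations $(\mathbf{y}, \mathbf{x})$ for the channels in $\mathcal{H}_0$, we train a collection of MMNets with weights $\{\boldsymbol{\theta}_k\}_{k=1}^{K}$, one per channel $\mathbf{H}^{(k)}$. Notice that, given the channel-specific nature of MMNet, the learned weights $\boldsymbol{\theta}_k$ entail good detection performance for the channel $\mathbf{H}^{(k)}$. We use these pretrained weights as regularizers during the training of our hypernetwork; see red arrows in Fig. 3. To be more precise, if we denote by $h_{\boldsymbol{\phi}}(\mathbf{H}^{(k)})$ the weights output by the hypernetwork when we input channel $\mathbf{H}^{(k)}$, we define our regularized loss as

$$\mathcal{L}(\boldsymbol{\phi}) = \mathbb{E}_{\mathbf{H},\mathbf{x},\mathbf{n}}\big[\|\mathbf{x} - \hat{\mathbf{x}}\|_2^2\big] + \lambda \sum_{k=1}^{K} \big\|\boldsymbol{\theta}_k - h_{\boldsymbol{\phi}}(\mathbf{H}^{(k)})\big\|_1, \tag{3}$$

where $\hat{\mathbf{x}}$ is the output of the MMNet whose weights $h_{\boldsymbol{\phi}}(\mathbf{H})$ are generated by the hypernetwork and $\lambda \geq 0$ is a tunable weight.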

Figure 3: Scheme of our proposed training architecture. In addition to the more classical hypernetwork training (blue arrows), we propose a regularization term that depends on several pretrained main (MMNet) networks, as depicted by the red arrows.

The first term in (3) computes a classical mean squared error between the true symbols $\mathbf{x}$ and the estimated symbols $\hat{\mathbf{x}}$, where the expected value is taken over the channel, input, and noise distributions of interest. The channel distribution used here is the one to which we want our channel-agnostic method to generalize. The second term penalizes the distance between the parameters $\boldsymbol{\theta}_k$ of the pretrained MMNets and those generated by the hypernetwork when it is fed with the channels in $\mathcal{H}_0$. We measure this discrepancy using an $\ell_1$ norm to promote a sparse difference between $\boldsymbol{\theta}_k$ and $h_{\boldsymbol{\phi}}(\mathbf{H}^{(k)})$. This means that the weights generated by the hypernetwork tend to coincide with $\boldsymbol{\theta}_k$ for a subset of the entries. The relative weight $\lambda$ captures the importance of performing well within the set $\mathcal{H}_0$. Indeed, when $\lambda = 0$, our proposed method boils down to HyperMIMO and completely ignores the prespecified channels $\mathcal{H}_0$. On the other hand, for $\lambda \to \infty$ (and assuming that the hypernetwork is sufficiently expressive), our method should mimic the behavior of MMNet on $\mathcal{H}_0$ but quickly degrade for generic channels. By selecting an intermediate value of $\lambda$, we can realize the desired behavior in Fig. 1, as we demonstrate in Section 4. The model is trained by minimizing the loss in (3) with respect to the hypernetwork parameters $\boldsymbol{\phi}$ through stochastic gradient descent.
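The regularized loss can be evaluated on a mini-batch as sketched below, where a batch mean stands in for the expectation. This sketch assumes the regularizer sums the $\ell_1$ distances over the $K$ pretrained MMNets and that all parameters are flattened into vectors of a common length `P`; `regularized_loss` is a hypothetical name. In practice this would be written in an autodiff framework so that gradients with respect to the hypernetwork parameters are available; plain NumPy only evaluates the value.

```python
import numpy as np

def regularized_loss(x_true, x_hat, theta_pre, theta_gen, lam):
    """Mini-batch estimate of the MSE term plus lam * sum_k ||theta_k - h_phi(H^(k))||_1.

    x_true, x_hat : (B, Nu) true and estimated symbols (batch mean approximates E[.])
    theta_pre     : (K, P) pretrained MMNet weights, one row per channel in the set
    theta_gen     : (K, P) hypernetwork outputs for the same K channels
    """
    mse = np.mean(np.sum(np.abs(x_true - x_hat) ** 2, axis=1))
    reg = np.sum(np.abs(theta_pre - theta_gen))
    return mse + lam * reg
```

Setting `lam=0` recovers a plain HyperMIMO-style objective, matching the limiting case discussed above.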

Before presenting the numerical experiments, two remarks are in order. First, although for concreteness we present our scheme in Fig. 3 for the case where the main network is an MMNet, the same framework can be applied for any generic channel-specific method taking the role of our main network. Second, although our proposed scheme incurs an additional training load in comparison with a vanilla hypernetwork, once trained their computational complexities for detection are exactly the same. We will refer to our proposed methodology particularized for the MMNet as HyperMIMO with learned regularizers, or HyperMIMO-LR for short.

Figure 4: (a) SER as a function of SNR for different detection methods evaluated on a set of channel sequences generated via (4). (b and c) SER as a function of time for SNR = and SNR = , respectively (the same legend as Fig. 4(a) holds).

4 Numerical Experiments

In this section, we present the results of our proposed method (code to replicate the numerical experiments is available online). We start by presenting the channel model, simulation setup, and neural network training process. Then, we present the experimental results and derive insights into the performance of HyperMIMO-LR.

4.1 Channel model

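The channel is generated following the Jakes model [Nasir2019]. We express the small-scale Rayleigh fading component as a first-order complex Gauss-Markov process

$$\mathbf{H}_t = \rho\,\mathbf{H}_{t-1} + \sqrt{1-\rho^2}\,\mathbf{E}_t, \qquad t = 1, 2, \dots, \tag{4}$$

where $\rho$ is the correlation coefficient between consecutive time hops and the entries of $\mathbf{E}_t$ are independent and identically distributed circularly symmetric complex Gaussian random variables. The initial matrix $\mathbf{H}_0$ is generated following the Kronecker correlated channel model

$$\mathbf{H}_0 = \mathbf{R}_{\mathrm{R}}^{1/2}\,\mathbf{G}\,\mathbf{R}_{\mathrm{T}}^{1/2}, \tag{5}$$

where the entries of $\mathbf{G}$ are i.i.d. $\mathcal{CN}(0, 1)$, and $\mathbf{R}_{\mathrm{R}}$ and $\mathbf{R}_{\mathrm{T}}$ are the spatial correlation matrices at the receiver and transmitter, respectively, generated according to the exponential correlation matrix model with a correlation coefficient $r$ [Loyka2001], i.e., with entries $[\mathbf{R}]_{ij} = r^{|i-j|}$. In our model, the signal-to-noise ratio (SNR) is given by

$$\mathrm{SNR} = \frac{\mathbb{E}\big[\|\mathbf{H}\mathbf{x}\|_2^2\big]}{\mathbb{E}\big[\|\mathbf{n}\|_2^2\big]}.$$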
For the experiments, SNRs between and are considered.
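The generation procedure can be sketched as follows. This sketch assumes that (4) has the first-order Gauss-Markov form $\mathbf{H}_t = \rho\mathbf{H}_{t-1} + \sqrt{1-\rho^2}\,\mathbf{E}_t$ with $\mathcal{CN}(0,1)$ innovations, that a single real exponential-correlation coefficient `r` is used at both ends of the link, and that matrix square roots are symmetric (eigendecomposition-based). The helper names are hypothetical, and the numerical values of $\rho$ and `r` used in the experiments are not reproduced here.

```python
import numpy as np

def exp_corr(n, r):
    """Exponential correlation matrix [R]_ij = r^|i-j| (real r assumed)."""
    idx = np.arange(n)
    return r ** np.abs(idx[:, None] - idx[None, :])

def crandn(shape, rng):
    """Array with i.i.d. CN(0, 1) entries."""
    return (rng.standard_normal(shape) + 1j * rng.standard_normal(shape)) / np.sqrt(2)

def psd_sqrt(R):
    """Symmetric square root of a PSD matrix via its eigendecomposition."""
    w, V = np.linalg.eigh(R)
    return V @ np.diag(np.sqrt(np.clip(w, 0, None))) @ V.conj().T

def kronecker_channel(Nr, Nu, r, rng):
    """H_0 = R_R^{1/2} G R_T^{1/2} with exponential correlation at both ends."""
    return psd_sqrt(exp_corr(Nr, r)) @ crandn((Nr, Nu), rng) @ psd_sqrt(exp_corr(Nu, r))

def jakes_sequence(H0, rho, T, rng):
    """H_t = rho H_{t-1} + sqrt(1 - rho^2) E_t for t = 1..T; returns [H_0, ..., H_T]."""
    seq = [H0]
    for _ in range(T):
        seq.append(rho * seq[-1] + np.sqrt(1 - rho ** 2) * crandn(H0.shape, rng))
    return seq
```

With `rho` close to 1, consecutive matrices in the sequence stay close to each other, which is the smooth-variation regime that motivates the method.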

4.2 Implementation

Our simulation environment includes a base station with $N_r$ receive antennas and $N_u$ transmitting single-antenna users. We consider 4-QAM modulation. The architecture of the hypernetwork is composed of three dense layers: the first layer has the same number of units as the input, the second one has 100 units, and the third one has as many units as the number of parameters that MMNet requires. For the MMNet, we use 6 layers. The activation function for all layers in the hypernetwork is an ELU function; we use an ELU rather than a ReLU because the generated MMNet parameters can take negative values.
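A minimal NumPy sketch of the forward pass of this hypernetwork: three dense layers with widths equal to the input size, 100, and the MMNet parameter count, each followed by ELU. The following are our assumptions: the input is the channel matrix flattened into its real and imaginary parts, weights use a simple scaled Gaussian initialization, and `n_params` is a placeholder for whatever parameter count the chosen MMNet needs.

```python
import numpy as np

def elu(a, alpha=1.0):
    """ELU: identity for a > 0, alpha * (exp(a) - 1) otherwise, so outputs can be negative."""
    return np.where(a > 0, a, alpha * np.expm1(np.minimum(a, 0)))

def init_hypernet(n_in, n_params, rng, hidden=100):
    """Three dense layers with widths n_in -> n_in -> hidden -> n_params."""
    sizes = [n_in, n_in, hidden, n_params]
    return [(rng.standard_normal((m, k)) / np.sqrt(k), np.zeros(m))
            for k, m in zip(sizes[:-1], sizes[1:])]

def hypernet_forward(layers, h):
    """Apply each dense layer followed by ELU, as described in the text."""
    for W, b in layers:
        h = elu(W @ h + b)
    return h
```

Note that ELU is bounded below by `-alpha` (here -1), so this choice allows negative generated parameters but within a limited range.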

Training. We use a batch size of 100 channel matrices generated from (5). The training is performed using the ADAM optimizer [kingma2017adam] with a reduce-on-plateau scheduler: we compute the loss every 500 iterations and, when the loss stops improving, the learning rate is reduced by a factor of . We train for iterations (for HyperMIMO, we followed the same scheme as in [hypermimo], changing only the lower limit to ). The value of $\lambda$ in (3) was set to . For the regularizer, we generate different sequences of length following (4) with and starting from the same initial matrix $\mathbf{H}_0$ from (5) with . In total, we use pre-trained MMNets.
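The reduce-on-plateau rule can be sketched as a small stateful object whose `step` is called at each loss check (every 500 iterations in the setup above). The reduction factor is not preserved in the text, so it is left as a parameter; `patience` and `min_delta` are hypothetical knobs. Mainstream frameworks provide equivalents, e.g. PyTorch's `torch.optim.lr_scheduler.ReduceLROnPlateau`.

```python
class ReduceOnPlateau:
    """Hypothetical sketch: multiply the learning rate by `factor` when the
    monitored loss has not improved for `patience` consecutive checks."""

    def __init__(self, lr, factor, patience=1, min_delta=0.0):
        self.lr, self.factor = lr, factor
        self.patience, self.min_delta = patience, min_delta
        self.best = float("inf")
        self.bad_checks = 0

    def step(self, loss):
        """Record one loss check and return the (possibly reduced) learning rate."""
        if loss < self.best - self.min_delta:
            self.best, self.bad_checks = loss, 0
        else:
            self.bad_checks += 1
            if self.bad_checks >= self.patience:
                self.lr *= self.factor
                self.bad_checks = 0
        return self.lr
```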

4.3 Simulation results

To test the performance of the detectors, we generate a test set of sequences of the same length from the same model in (4), also starting from $\mathbf{H}_0$.

We compare the SER achieved by HyperMIMO-LR with that of the following methods: HyperMIMO, MMNet, DetNet with fixed and varying channels, MMSE, and ML (using the Gurobi solver [gurobi]). The comparisons are shown in Fig. 4(a). The figure reveals that the performance of HyperMIMO-LR is the closest to the optimal ML and that it outperforms all the other methods, in particular both HyperMIMO and MMNet. It is particularly interesting to observe that, while HyperMIMO-LR consistently outperforms the classical MMSE detector, HyperMIMO performs worse than MMSE. This is because the performance of HyperMIMO decreases significantly when it is tested on perturbed versions of a channel from the distribution, while HyperMIMO-LR performs robustly on those unseen channels.
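SER curves such as those in Fig. 4 reduce to counting mismatched symbols after hard decisions. The minimal sketch below assumes detected symbols have already been sliced to constellation points; for the per-hop variant, the convention that the hop index sits on axis 1 of the arrays is our assumption.

```python
import numpy as np

def symbol_error_rate(x_true, x_hat):
    """Fraction of transmitted symbols detected incorrectly."""
    x_true, x_hat = np.asarray(x_true), np.asarray(x_hat)
    return float(np.mean(~np.isclose(x_true, x_hat)))

def ser_per_hop(x_true, x_hat):
    """SER for each time hop, assuming the hop index is axis 1, e.g. shape (S, T+1, Nu).
    Averages errors over every other axis (sequences, users, any batch axes)."""
    errors = ~np.isclose(np.asarray(x_true), np.asarray(x_hat))
    other_axes = tuple(i for i in range(errors.ndim) if i != 1)
    return errors.mean(axis=other_axes)
```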

The performance of the detectors as a function of the time hop $t$ is shown in Figs. 4(b) and 4(c), for an SNR of and , respectively. We use the same test set as in the previous experiment. In both cases, we observe equal SERs at $t = 0$ (initial matrix $\mathbf{H}_0$) for MMNet and HyperMIMO-LR. This is expected because, due to the regularizer, the parameters of both architectures are similar, and hence both attain essentially the same performance at the initial hop. We also see that the performance of DetNet-FC at the initial hop is close to that of HyperMIMO-LR and MMNet, but the performance of both MMNet and DetNet-FC quickly degrades as $t$ increases. Moreover, HyperMIMO follows a trend similar to that of HyperMIMO-LR, meaning that its performance does not drop severely with $t$, but it is nonetheless inferior to HyperMIMO-LR. Overall, this behavior is what we expected from our motivation in Fig. 1. Lastly, we see that MMSE performs relatively better for later hops. This can be explained by looking at the Jakes model in (4): as we get farther from $\mathbf{H}_0$, the Gaussian component tends to dominate, and in such a Gaussian regime MMSE achieves very good performance.
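The argument about later hops can be made quantitative by unrolling the recursion. Assuming (4) takes the first-order Gauss-Markov form $\mathbf{H}_t = \rho\mathbf{H}_{t-1} + \sqrt{1-\rho^2}\,\mathbf{E}_t$ with real $|\rho| < 1$ and with the $\mathbf{E}_s$ having i.i.d. $\mathcal{CN}(0,1)$ entries independent of $\mathbf{H}_0$ (our reading of the model), a short derivation gives:

```latex
\mathbf{H}_t = \rho^{t}\,\mathbf{H}_0 + \sqrt{1-\rho^{2}}\sum_{s=1}^{t}\rho^{\,t-s}\,\mathbf{E}_s ,
\qquad
\operatorname{Var}\!\Big(\Big[\sqrt{1-\rho^{2}}\sum_{s=1}^{t}\rho^{\,t-s}\mathbf{E}_s\Big]_{ij}\Big)
= (1-\rho^{2})\sum_{m=0}^{t-1}\rho^{2m} = 1-\rho^{2t}.
```

The contribution of the initial (correlated) matrix $\mathbf{H}_0$ therefore decays as $\rho^{t}$, while the i.i.d. Gaussian part has per-entry variance $1-\rho^{2t} \to 1$. For large $t$, $\mathbf{H}_t$ behaves like an uncorrelated Rayleigh channel, which is the Gaussian regime in which the text observes MMSE doing well.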

5 Conclusions

We proposed a general deep learning-based solution for MIMO detection that achieves high performance for perturbations of a prespecified set of channels while generalizing to the whole distribution. This was done by regularizing the training of a hypernetwork that generates the weights of a deep learning-based detector with pretrained instances of that detector, each trained on one of the specific channels. We evaluated this general architecture with an implementation that uses HyperMIMO, a hypernetwork-based solution that incorporates MMNet as its deep learning-based MIMO detector. We demonstrated that our implementation, named HyperMIMO-LR, generalizes well to the whole distribution of channels and outperforms HyperMIMO. Future work includes extending the method to higher-order systems as well as higher-order modulations.