The advancements of the last few years in medical image segmentation have been dominated by deep learning (DL) approaches. DL mostly eliminated the need for handcrafting image features; however, it has arguably been replaced by the need for domain experts to design application-specific DL models. In particular, the medical image computing field has been dominated by popular hand-engineered network architectures such as 2D and 3D U-Net [ronneberger2015u, cciccek20163d], V-Net [milletari2016v], High-Res-Net [li2017compactness], DeepMedic [kamnitsas2016deepmedic], and many others. One promising direction for obtaining a good network design for a particular problem is to automate the time-consuming model design process via AutoML techniques. Another major challenge in model development is that large amounts of data covering a sufficiently large range of examples are usually required to train accurate and robust models. To achieve this goal, hospitals and medical institutes often need to collaborate and host centralized databases for the development of clinical-grade DL models. This can become challenging due to data privacy and various ethical concerns associated with data sharing in the healthcare domain. One approach to combat such issues is federated learning (FL), where only model and/or DL workflow parameters are shared among participating institutes instead of raw medical data. Furthermore, it is well known that global robustness and local accuracy are in many cases conflicting: models trained on large centralized datasets might not always generalize well to the data at a particular imaging site due to various inconsistencies (scanner models, imaging protocols, patient populations, etc.) among the different sites. In this case, domain adaptation (DA) is often needed. In this work, we propose to systematically tackle these three challenges in a unified framework, combining an FL algorithm with AutoML and the capability of global-local model adaptation.
In particular, we implement a “supernet” training strategy that can be used in a federated setting. We believe AutoML and FL technologies are a natural fit for each other because of their complementary nature, and by combining the two, we are also able to address the DA problem. For one, FL can circumvent the problem of hosting and accessing large centralized datasets by distributing the learning effort to several clients with their own local data. FL only communicates the model updates after a local round of training to a centralized server, which aggregates the results and starts the next round of FL. At the same time, AutoML with a supernet design allows us to avoid hand-engineering of dataset-specific network architectures, and a particular sub-network of the trained supernet can be used as a form of local domain adaptation to handle inconsistencies between the different contributing data sites. Next, we summarize related works.
Recently, deep learning has been applied to various applications, such as image recognition, semantic segmentation, object detection, and natural image generation. However, for each specific task, particular network architectures often need to be hand-designed. Neural architecture search (NAS) [elsken2018neural] is one of the most common approaches in AutoML to circumvent such hand-design of architectures for DL applications. The goal of NAS is to automatically design neural network architectures without human heuristics or assumptions. After searching, not only the model weights but also the model architecture itself is optimized for the task at hand, while often still being generalizable to other datasets [yu2019c2fnas]. A common concept in the (one-shot) NAS and AutoML literature is the “supernet” [liu2018darts, cai2018proxylessnas, you2020greedynas]. The main idea behind a supernet is to create a large neural network containing several candidate modules at each level of the network. This supernet can be trained jointly, and specific sub-networks can be chosen from it by selecting a path through the module candidates. At deployment, the final architecture is selected from the supernet by assigning path weights that select particular module candidates. Additional budgeting constraints, such as latency or number of model parameters, can be added to find optimal architectures for a given application. Recent works achieve state-of-the-art results on computer vision tasks while being computationally efficient [tan2019efficientnet].
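The supernet idea can be sketched as follows. This is a minimal PyTorch sketch under our assumptions (illustrative channel counts, candidate operations, and class names, not the exact architecture used in this work): each level holds several candidate modules, and a forward pass samples one candidate per level uniformly at random, defining a sub-network.

```python
# Minimal one-shot supernet sketch (illustrative, not the paper's exact model).
import random
import torch
import torch.nn as nn

class ResBlock(nn.Module):
    """Two stacked 3D convolutions with a residual (skip) connection."""
    def __init__(self, ch):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv3d(ch, ch, 3, padding=1), nn.ReLU(),
            nn.Conv3d(ch, ch, 3, padding=1))

    def forward(self, x):
        return x + self.body(x)

class CandidateLevel(nn.Module):
    """One supernet level holding several candidate operations."""
    def __init__(self, ch):
        super().__init__()
        self.candidates = nn.ModuleList([
            nn.Conv3d(ch, ch, 3, padding=1),                       # 3D conv
            ResBlock(ch),                                          # 3D residual block
            nn.Conv3d(ch, ch, (3, 3, 1), padding=(1, 1, 0)),       # "2D" conv (axial)
            nn.Identity(),                                         # identity / skip
        ])

    def forward(self, x, choice):
        return self.candidates[choice](x)

class SuperNet(nn.Module):
    def __init__(self, channels=8, num_levels=4):
        super().__init__()
        self.levels = nn.ModuleList(CandidateLevel(channels) for _ in range(num_levels))

    def forward(self, x, path=None):
        # Uniform path sampling: pick one candidate per level if no path given.
        if path is None:
            path = [random.randrange(len(l.candidates)) for l in self.levels]
        for level, choice in zip(self.levels, path):
            x = level(x, choice)
        return x, path
```

At deployment, a fixed `path` (one candidate index per level) selects a single sub-network from the jointly trained supernet.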
FL enables collaborative and decentralized DL training without sharing raw patient data [mcmahan2016communication]. Each client in FL trains locally on its own data and then submits its model parameters to a server that accumulates and aggregates the model updates from each client. Once a certain number of clients have submitted their updates, the aggregated model parameters are redistributed to the clients for a local model update, and a new round of local training starts. While out of the scope of this work, FL can also be combined with additional privacy-preserving measures to avoid potential reconstruction of training data through model inversion if the model parameters were leaked to an adversary [li2019privacy]. Several works have shown the applicability of FL to medical imaging tasks [sheller2018multi, li2019privacy, YANG2021101992]. Recent work combining NAS approaches with FL has been proposed for mobile phone applications [zhu2020real]. As such, its focus is on reducing the computational requirements on local edge devices, making its setting quite different from the “cross-silo” FL [kairouz2019advances] medical image segmentation investigated here, where the focus is on model performance and personalization. The closest work in motivation to ours is [he2020towards], which focuses on the non-I.I.D. setting but is restricted to toy classification datasets, such as CIFAR-10, and differs in its implementation details.
Domain adaptation aims to tackle data inconsistencies among different domains, or between training data and unseen data. In its simplest form, fine-tuning, also known as transfer learning [shin2016deep], can help adapt a pre-trained model to a particular target domain. More recent approaches for DA typically involve some form of adversarial learning, either introducing a specific loss that minimizes feature-level differences among domains [kamnitsas2017unsupervised] or back-propagating gradients through adversarial training [ganin2015unsupervised, ganin2016domain]. An alternative approach comes from the “image translation” field, where generative adversarial networks (GANs) are used to translate images of one domain to mimic another. An important ingredient of these approaches is some form of cycle consistency, which is essential for training on unpaired data [isola2017image, zhu2017unpaired, zhang2018task]. The common concept of adversarial training suggests that gradients from external constraints help balance the different domains and change the model’s feature representations.
Our proposed approach here is similar in that we will ultimately adapt the model’s internal feature representations through the selection of an adapted sub-network of the trained supernet, but without the need to use computationally expensive adversarial learning schemes. Our contributions can be summarized as follows:
We show that we can successfully train models through federated learning with comparable or better performance to models trained on centrally hosted data.
We extend federated learning by introducing an AutoML approach for supernet model training.
We show that finding an optimal path through the supernet can act as a form of local domain adaptation and bring performance gains for each individual client.
Here we describe the technical details of the FL and AutoML approach utilized in this work. The proposed method can be separated into two steps: 1) FL with AutoML supernet training and 2) local model adaptation by finding the best path through the supernet with respect to the local data. Both FL and AutoML procedures presented are designed for 3D medical image segmentation tasks.
Client-Server-Based Federated Learning:
In its typical form, FL utilizes a client-server setup. Each client trains the same model architecture locally on its own data. Once a certain number of clients have finished local training, the updated model weights (or their gradients) are sent to the server for aggregation. After aggregation, the new weights on the server are re-distributed to the clients, which then execute the next round of local model training. After several FL rounds, the models at each client converge. Each client can select its local best model by monitoring a certain performance metric on a local held-out validation set. In our experiments, we implement the FederatedAveraging algorithm proposed in [mcmahan2016communication]. While there exist variants of this algorithm to address particular learning tasks, in its most general form, FL tries to minimize a global loss function $\mathcal{L}$, which can be a weighted combination of $K$ local losses $\mathcal{L}_k$, each computed on a client $k$'s local data $X_k$. Hence, FL can be formulated as the task of finding the model parameters $w$ that minimize $\mathcal{L}$ given some local data $X_k$:
\[
\min_{w} \; \mathcal{L}(w) = \sum_{k=1}^{K} \alpha_k \, \mathcal{L}_k(w; X_k),
\]
where $\alpha_k$ denote the weight coefficients for each client $k$, respectively. Note that the local data $X_k$ is never shared among the different clients. Only the model weights are accumulated and aggregated on the server, as shown in Algorithm 1.
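The server-side aggregation step of FederatedAveraging can be sketched as follows. This is a minimal sketch under our assumptions (the function name and the dict-of-arrays weight format are illustrative): each client contributes its updated weights and its number of training samples, and the server forms the weighted average with coefficients proportional to the client sample counts.

```python
import numpy as np

def fedavg(client_weights, client_sizes):
    """Weighted FederatedAveraging aggregation: w = sum_k alpha_k * w_k,
    with alpha_k proportional to client k's number of training samples.
    client_weights: list of {param_name: ndarray} dicts, one per client."""
    total = float(sum(client_sizes))
    alphas = [n / total for n in client_sizes]
    aggregated = {}
    for name in client_weights[0]:
        aggregated[name] = sum(a * w[name] for a, w in zip(alphas, client_weights))
    return aggregated
```

The aggregated weights are then redistributed to all clients, which resume local training from this common starting point.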
AutoML with Supernet:
In order to allow for personalized neural architectures, we designed a supernet consisting of various DL module candidates suitable for 3D medical imaging tasks, shown in Fig. 1(a). Each candidate is a subgraph of the supernet with its own model weights. These modules are optimized at multiple resolution levels to capture different levels of image features useful for the segmentation task. In general, we follow the popular encoder-decoder structure which has been successfully applied to many medical imaging tasks [ronneberger2015u, cciccek20163d, milletari2016v], shown in Fig. 1(b), with skip connections that concatenate features of the encoder with their corresponding layer in the decoder path. During training, we choose arbitrary paths through the module candidates following a uniform sampling scheme (see Fig. 1(c)) to define a sub-network sampled from the supernet (Eq. 2).
In this work, we choose the combination of Dice loss [milletari2016v] and cross-entropy loss as our loss function, which is commonly used for segmentation tasks in medical imaging [isensee2021nnu]. The Dice loss’ major advantage is its ability to work well in segmentation tasks with an imbalance between foreground and background regions. Once the supernet is trained, we can find a sub-network by identifying a locally optimal path through the supernet, effectively adapting the model to the target domain. During adaptation, the model parameters stay fixed and only the path weights are optimized for one epoch on the local validation set. This results in an optimal path that defines our locally adapted sub-network (Eq. 3).
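The adaptation step can be sketched with a softmax relaxation over per-level path logits, in the style of differentiable NAS. This is an illustrative assumption (the paper's exact path-weight parameterization may differ), and toy linear candidates stand in for the 3D segmentation modules; the key point is that the candidate weights are frozen while only the path logits are optimized on validation data.

```python
import torch
import torch.nn as nn

class MixedLevel(nn.Module):
    """One supernet level: output is the softmax(alpha)-weighted sum of
    candidate outputs, where alpha are learnable path logits."""
    def __init__(self, dim, num_cands=3):
        super().__init__()
        self.cands = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_cands))
        self.alpha = nn.Parameter(torch.zeros(num_cands))  # path logits

    def forward(self, x):
        w = torch.softmax(self.alpha, dim=0)
        return sum(wi * c(x) for wi, c in zip(w, self.cands))

def adapt(model, val_x, val_y, steps=20, lr=0.1):
    """Local adaptation: freeze all model weights, optimize only the path
    logits on the client's validation data, then return the argmax path."""
    for p in model.parameters():
        p.requires_grad_(False)
    alphas = [m.alpha for m in model.modules() if isinstance(m, MixedLevel)]
    for a in alphas:
        a.requires_grad_(True)
    opt = torch.optim.Adam(alphas, lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        loss = nn.functional.mse_loss(model(val_x), val_y)
        loss.backward()
        opt.step()
    # The locally optimal path picks the highest-weighted candidate per level.
    return [int(a.argmax()) for a in alphas]
```

After adaptation, the returned path indices select one fixed candidate per level, yielding the locally adapted sub-network used for inference.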
3 Experiments & Results
Our proposed method is evaluated on the task of 3D whole prostate segmentation in T2-weighted MRI. MRI is particularly prone to data inconsistencies due to variations in the imaging protocols and scanners used at each data-contributing site, potentially causing drastic variations in contrast and intensity values.
We utilize prostate MRI datasets from four different publicly available data sources: MSD-Prostate (http://medicaldecathlon.com) [simpson2019large], PROMISE12 (https://promise12.grand-challenge.org) [litjens2014evaluation], NCI-ISBI13 (http://doi.org/10.7937/K9/TCIA.2015.zF0vlOPv), and ProstateX (https://prostatex.grand-challenge.org) [litjens2014computer]. For each dataset, we perform three random splits into training, validation, and testing sets at roughly 70%, 10%, and 20% of the total number of cases of each dataset. The resulting number of cases for each dataset is shown in Table 1. We average results across the testing splits of each random split. For reference, we show the results on a centralized dataset where all four datasets have been combined. We also compare the performance of models trained locally and through federated learning on each dataset’s testing split. The performance of a standard 3D U-Net [cciccek20163d], which is a subgraph of our supernet (when all candidates are of type 1), is shown as a baseline comparison. We resample each image to a constant resolution of 0.5 mm × 0.5 mm × 1.0 mm and normalize all non-zero image intensities by subtracting their mean and dividing by their standard deviation on a per-image basis.
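The per-image intensity normalization described above can be sketched in NumPy as follows (a simple sketch; the function name is ours): only non-zero voxels are standardized, leaving background zeros untouched.

```python
import numpy as np

def normalize_nonzero(img):
    """Standardize non-zero intensities to zero mean / unit std per image;
    zero (background) voxels are left unchanged. A small epsilon guards
    against division by zero for near-constant images."""
    out = img.astype(np.float32)
    mask = out != 0
    vals = out[mask]
    out[mask] = (vals - vals.mean()) / (vals.std() + 1e-8)
    return out
```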
Both U-Net and the supernet are trained using randomly cropped patches of fixed size from the input images and labels. We used a mini-batch size of 18 by selecting 3 random crops from 6 random input image and label pairs. As the optimizer for training the supernet, we chose NovoGrad, which typically converges faster than the more commonly used Adam optimizer [ginsburg2019stochastic]. For finding the optimal path of the final sub-network, we use the Adam optimizer. Augmentation techniques such as random intensity shifts, contrast adjustments, and additive Gaussian noise are applied during training to avoid overfitting to the training set. Because of its large number of possible path combinations, the supernet is trained 10× longer than the 3D U-Net to give most paths the opportunity to be trained well. Both the 3D U-Net baseline and the supernet are implemented in PyTorch (https://pytorch.org) using components from MONAI (https://monai.io) and NVFlare (https://pypi.org/project/nvflare) for FL communication. All models are trained on NVIDIA V100 GPUs with 16 GB memory. We monitor convergence on randomly chosen paths sampled from a uniform distribution during each validation to determine when the supernet is sufficiently trained across clients. The number of training iterations is chosen such that the expected number of times each path is selected over the entire training is at least 1.
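The path-coverage criterion above can be checked with a quick back-of-envelope calculation (illustrative helper functions with assumed example numbers; the actual path count of our supernet is not restated here): with one uniformly sampled path per iteration, each of the P possible paths is visited T/P times in expectation over T iterations.

```python
def expected_visits(num_paths, num_iters):
    """Expected number of times a given path is sampled over num_iters
    iterations of uniform path sampling."""
    return num_iters / num_paths

def prob_never_visited(num_paths, num_iters):
    """Probability that a specific path is never sampled during training."""
    return (1 - 1 / num_paths) ** num_iters
```

For example, the criterion "expected visits >= 1" simply requires at least as many training iterations as there are paths.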
Table 2 shows the better generalizability of supernet models trained with FL. We show the performance of the proposed supernet training approach and its adaptation to the local dataset distribution via path optimization, together with a baseline implementation of 3D U-Net using the same augmentation, optimization, and hyperparameters for comparability. Visualizations of the results before and after model adaptation are shown in Fig. 5. In descending order, the most commonly chosen operations were 3D conv., 3D residual block, 2D conv., followed by identity.
| Avg. Dice [%]      | Central | NCI   | PROMISE12 | ProstateX | MSD   | Avg. (loc.) |
|--------------------|---------|-------|-----------|-----------|-------|-------------|
| SN (loc.) + adapt. | 90.15   | 90.50 | 83.46     | 90.78     | 87.83 | 88.14       |
| SN (fed.) + adapt. | 90.68   | –     | 86.15     | 90.65     | 88.74 | 89.06       |
[Table 2: generalizability comparison of SN (loc.), SN (fed.), and SN (fed.) + adapt. across the different test sites]
4 Discussion & Conclusions
It can be observed from Table 1 that supernet training with local adaptation in FL (SN (fed.) + adapt.) achieves the highest average Dice score on the local datasets. At the same time, the adapted models also show the best generalizability (see Table 2). This illustrates the viability of supernet training with local model adaptation to the client’s data. We furthermore observe a general improvement of the local supernet models’ performance when trained in an FL setting versus local training. This means that supernet model training in particular can benefit from the larger effective training set size made available through FL without having to share any of the raw image data between clients. Overall, we achieve average Dice scores comparable to recent literature on whole prostate segmentation in MRI [litjens2014computer, litjens2014evaluation, milletari2016v], which could likely be improved with more aggressive data augmentation schemes [zhang2020generalizing, isensee2021nnu]. Further fine-tuning of the network weights (not the supernet path weights) is likely to give a performance boost on a local client but is also expected to reduce the generalizability of the model. Fine-tuning methods that do not reduce the robustness to other data sources (i.e., generalizability) gained through FL (e.g., learning without forgetting [li2017learning]) remain an open research question and were deemed out of scope for this work.
In conclusion, we proposed to combine the advantages of both federated learning and AutoML. The two techniques are complementary, and in combination they allow for an implicit domain adaptation by finding locally optimal model architectures (sub-networks of the supernet) for a client’s dataset. We showed that the performance of federated learning is comparable to the model’s performance when the dataset is centrally hosted. After local adaptation via choosing the optimal path through the supernet, we observe an additional performance gain on the client’s data. In the future, one could explore whether a set of optimal sub-networks could act as an ensemble during inference to further improve performance and provide additional estimates such as model uncertainty. Furthermore, one could adaptively change the path frequencies used during supernet training based on sub-network architectures that work well on each client, in order to reduce communication cost and speed up training.