Supplementary code for the paper "SplitGuard: Detecting and Mitigating Training-Hijacking Attacks in Split Learning"
Distributed deep learning frameworks, such as split learning, have recently been proposed to enable a group of participants to collaboratively train a deep neural network without sharing their raw data. Split learning in particular achieves this goal by dividing a neural network between a client and a server so that the client computes the initial set of layers, and the server computes the rest. However, this method introduces a unique attack vector for a malicious server attempting to steal the client's private data: the server can direct the client model towards learning a task of its choice. With a concrete example already proposed, such training-hijacking attacks present a significant risk for the data privacy of split learning clients. In this paper, we propose SplitGuard, a method by which a split learning client can detect whether it is being targeted by a training-hijacking attack or not. We experimentally evaluate its effectiveness, and discuss in detail various points related to its use. We conclude that SplitGuard can effectively detect training-hijacking attacks while minimizing the amount of information recovered by the adversaries.
As neural networks, and more specifically deep neural networks (DNNs), began outperforming traditional machine learning methods in tasks such as natural language processing, they became the workhorses driving the field of machine learning forward. However, effectively training a DNN requires large amounts of computational power and high-quality data. Yet relying on a sustained increase in computing power is unsustainable, and it may not be possible to share data freely in fields such as healthcare [1, 15].
To alleviate these two problems, distributed deep learning methods such as split learning (SplitNN) [22, 6] and federated learning (FL) [2, 11, 12] have been proposed. They fulfill their purpose by enabling a group of data-holders to collaboratively train a DNN without sharing their private data, while offloading some of the computational work to a more powerful server.
In FL, each client trains a DNN using its local data, and sends its parameter updates to the central server; the server then aggregates the updates in some way (e.g. average) and sends the aggregated results back to each client. In SplitNN, a DNN is split into two parts and the clients train in a round-robin manner. The client taking its turn computes the first few layers of the DNN and sends the output to the server, who then computes the DNN’s overall output and starts the parameter updates by calculating the loss value. In both methods, no client shares its private data with another party, and all clients end up with the same model.
Motivation. In SplitNN, the server has control over the parameter updates being propagated back to each client model. This creates a new attack vector for a malicious server trying to infer the clients' private data, one that has already been exploited in an attack proposed by Pasquini et al. By contrast, this attack vector does not exist in federated learning, since the clients can trivially check whether their model is aligned with their goals by calculating its accuracy. The same check is not possible in split learning, since the adversary can train a legitimate model on the side using the clients' intermediate outputs, and use that model as a performance measure. In fact, any such detection protocol that expects cooperation from the server is doomed to fail, because the server can use a legitimate surrogate model as described.
Our main contribution in this paper is SplitGuard, a protocol by which a SplitNN client can detect, without expecting cooperation from the server, whether its local model is being hijacked. To the best of our knowledge, SplitGuard is the first attempt at detecting training-hijacking attacks against split learning clients. To achieve our goal, we utilize the observation that if a client's local model is learning the intended task, then it should behave in a drastically different way when the task is reversed (i.e., when success in the original task implies failure in the new task). We demonstrate using three commonly used benchmark datasets (MNIST, Fashion-MNIST, and CIFAR10) that SplitGuard effectively detects and mitigates the only training-hijacking attack proposed so far. We further argue that it generalizes to any such training-hijacking attack.
In the rest of the paper, we first provide the necessary background on DNNs and SplitNN, and explain some of the related work. We then describe SplitGuard, experimentally evaluate it, and discuss certain points pertaining to its use. We conclude by providing an outline of possible future work related to SplitGuard.
Supplementary code for the paper can be found at https://github.com/ege-erdogan/splitguard.
A neural network $f: \mathcal{X} \to \mathcal{Y}$ is a parameterized function that tries to approximate a function $f^*: \mathcal{X} \to \mathcal{Y}$. The goal of the training procedure is to learn the parameters of $f$ using a training set consisting of examples $x$ and labels $y$ sampled from the real-world distributions $\mathcal{X}$ and $\mathcal{Y}$.
A typical neural network, also called a feedforward neural network, consists of discrete units called neurons, organized into layers. Each neuron in a layer takes in a weighted sum of the previous layer's neurons' outputs, applies a non-linear activation function, and outputs the result. The weights connecting the layers to each other constitute the parameters that are updated during training. Considering each layer as a separate function, we can model a neural network as a chain of functions, and represent it as $f = f^{(K)} \circ f^{(K-1)} \circ \cdots \circ f^{(1)}$, where $f^{(1)}$ corresponds to the first layer, $f^{(2)}$ to the second layer, and $f^{(K)}$ to the final, or output, layer.
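The chain-of-functions view can be illustrated with a minimal sketch. It assumes toy single-neuron layers on scalars; the layer functions `f1`, `f2`, `f3` and the helper `compose` are hypothetical, chosen only to make the order of composition easy to follow:

```python
from functools import reduce
from typing import Callable, Sequence

Layer = Callable[[float], float]


def compose(layers: Sequence[Layer]) -> Layer:
    """Return f = f^(K) o ... o f^(1), where layers[0] is f^(1) and is applied first."""
    def f(x: float) -> float:
        return reduce(lambda acc, layer: layer(acc), layers, x)
    return f


def relu(z: float) -> float:
    return max(0.0, z)


# Hypothetical single-neuron layers: a weighted input followed by an activation.
def f1(x: float) -> float:
    return relu(2.0 * x - 1.0)


def f2(h: float) -> float:
    return relu(-1.0 * h + 3.0)


def f3(h: float) -> float:
    return 0.5 * h  # linear output layer


f = compose([f1, f2, f3])
print(f(1.0), f(2.0))  # 1.0 0.0
```

Splitting the list `[f1, f2, f3]` at some index is exactly the kind of partition that split learning performs between client and server.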
Like many other machine learning methods, training a neural network involves minimizing a loss function. However, since the nonlinearity introduced by the activation functions applied at each neuron causes the loss function to become non-convex, we use iterative, gradient-based approaches to minimize the loss function. It is important to note that these methods do not provide any global convergence guarantees.
A widely used optimization method is stochastic gradient descent (SGD). Rather than computing the gradient from the entire data set, SGD computes gradients for batches selected from the data set. The weights are updated by propagating the error backwards using the backpropagation algorithm. Training a deep neural network generally requires multiple passes over the entire data set, each such pass being called an epoch. One round of training a neural network requires two passes through the network: one forward pass to compute the network's output, and one backward pass to update the weights. We will use the terms forward pass and backward pass to refer to these operations in the following sections. For an overview of gradient-based optimization methods other than SGD, we refer the reader to the literature.
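A minimal sketch of minibatch SGD follows. It assumes a one-parameter linear model $y \approx w x$ with a squared-error loss; the function `sgd_linear` and its default hyperparameters are illustrative and are not taken from the paper's code. Each inner iteration estimates the gradient from a single batch, and each outer iteration is one epoch:

```python
import random
from typing import Sequence


def sgd_linear(xs: Sequence[float], ys: Sequence[float], lr: float = 0.05,
               batch_size: int = 8, epochs: int = 50, seed: int = 0) -> float:
    """Fit y ~ w * x by minibatch SGD on the mean squared error."""
    rng = random.Random(seed)
    w = 0.0
    idx = list(range(len(xs)))
    for _ in range(epochs):                 # one epoch = one pass over the data
        rng.shuffle(idx)
        for start in range(0, len(idx), batch_size):
            batch = idx[start:start + batch_size]
            # Gradient of mean((w*x - y)^2) w.r.t. w, computed on this batch only.
            grad = sum(2 * (w * xs[i] - ys[i]) * xs[i] for i in batch) / len(batch)
            w -= lr * grad
    return w


xs = [i / 10 for i in range(-20, 21)]
ys = [3.0 * x for x in xs]
w = sgd_linear(xs, ys)
```

With a single layer there is nothing to backpropagate through; in a deep network, the batch gradient for every layer would be obtained by backpropagation.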
In split learning (SplitNN) [6, 23, 22], a DNN is split between the clients and a server such that the clients compute the first few layers, and the server computes the rest of the layers. This way, a group of clients can train a DNN utilizing, but not sharing, their collective data. Furthermore, most of the computational work is also offloaded to the server, reducing the training cost for the clients. However, this partitioning involves a privacy/cost trade-off for the clients, with the outputs of earlier layers leaking more information about the inputs.
Figure 1 displays the two basic modes of SplitNN, the main difference between the two being whether the clients share their labels with the server or not. In Figure 1(a), clients compute only the first few layers, and must share their labels with the server. The server then computes the loss value, starts backpropagation, and sends the gradients of its first layer back to the client, who then completes the backward pass. The private-label scenario depicted in Figure 1(b) follows the same procedure, with an additional communication step: since the client now computes the loss value and initiates backpropagation, it must first send the server the gradient values needed to resume backpropagation through the server model.
For our purposes, it is important to realize that the server can launch a training-hijacking attack even in the private-label scenario (Figure 1(b)). It simply discards the gradients it receives from the second part of the client model, computes a malicious loss function using the intermediate output it received from the first client model, and propagates the malicious loss back to the first client model.
The primary advantage of SplitNN compared to federated learning is its lower communication load. While federated learning clients have to share their entire parameter updates with the server, SplitNN clients only share the output of a single layer. However, choosing an appropriate split depth is crucial for SplitNN to actually provide data privacy. If the initial client model is too shallow, an honest-but-curious server can recover the private inputs with high accuracy, knowing only the model architecture (not the parameters) on the clients' side. This implies that SplitNN clients should increase their computational load, by computing more layers, for stronger data privacy.
Finally, SplitNN follows a round-robin training protocol to accommodate multiple clients; clients take turns training with the server using their local data. Before a client starts its turn, it should bring its parameters up to date with those of the most recently trained client. There are two ways to achieve this: the clients can either share their parameters through a central parameter server, or directly communicate with each other in a P2P way and update their parameters.
The Feature-Space Hijacking Attack (FSHA) by Pasquini et al. is the only training-hijacking attack against SplitNN clients proposed so far. It is important to understand how a training-hijacking attack might work before discussing SplitGuard in detail.
In FSHA, the attacker (SplitNN server) first trains an autoencoder (consisting of the encoder $\tilde{f}$ and the decoder $\tilde{f}^{-1}$) on some public dataset $X_{pub}$ similar to the client's private dataset $X_{priv}$. It is important for the attack's effectiveness that $X_{pub}$ be similar to $X_{priv}$; without such a dataset, the attack cannot be launched at all. The main idea is then for the server to bring the output spaces of the client model $f$ and the encoder $\tilde{f}$ as close as possible, so that the decoder can successfully invert the client outputs and recover the private inputs.
After this initial setup phase, the client model's training begins. For this step, the attacker initializes a distinguisher model $D$ that tries to distinguish the client's output $f(X_{priv})$ from the encoder's output $\tilde{f}(X_{pub})$. More formally, the distinguisher is updated at each iteration to minimize the loss function
$$\mathcal{L}_D = \log\left(1 - D(\tilde{f}(X_{pub}))\right) + \log\left(D(f(X_{priv}))\right).$$
Simultaneously, at each training iteration, the server directs the client model towards maximizing the distinguisher's error rate, thus minimizing the loss function
$$\mathcal{L}_f = \log\left(1 - D(f(X_{priv}))\right).$$
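To see that neither objective involves the client's labels, both can be evaluated directly from the distinguisher's outputs. This is a minimal sketch, not Pasquini et al.'s implementation. It assumes that $D$ outputs a probability in $(0, 1)$ and that each term is averaged over its batch, following $\mathcal{L}_D = \log(1 - D(\tilde{f}(X_{pub}))) + \log(D(f(X_{priv})))$ and $\mathcal{L}_f = \log(1 - D(f(X_{priv})))$:

```python
import math
from typing import Sequence


def distinguisher_loss(d_pub: Sequence[float], d_priv: Sequence[float]) -> float:
    """L_D = log(1 - D(f~(X_pub))) + log(D(f(X_priv))), each term averaged over its batch."""
    pub_term = sum(math.log(1.0 - p) for p in d_pub) / len(d_pub)
    priv_term = sum(math.log(p) for p in d_priv) / len(d_priv)
    return pub_term + priv_term


def client_loss(d_priv: Sequence[float]) -> float:
    """L_f = log(1 - D(f(X_priv))); the client's labels never appear."""
    return sum(math.log(1.0 - p) for p in d_priv) / len(d_priv)
```

Minimizing $\mathcal{L}_D$ pushes $D$ towards 1 on encoder outputs and towards 0 on client outputs. Minimizing $\mathcal{L}_f$ pushes $D$ towards 1 on client outputs, which means the client outputs come to look like encoder outputs.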
In the end, the output spaces of the client model and the server’s encoder are expected to overlap to a great extent, making it possible for the decoder to invert the client’s outputs.
Notice that the client's loss function is totally independent of the training labels, in the sense that changing the values of the labels does not affect the loss function. We will soon refer to this observation.
We start our presentation of SplitGuard by restating an earlier remark: If the training-hijacking detection protocol requires the attacker SplitNN server to knowingly take part in the protocol, the server can easily circumvent the protocol by training a legitimate model on the side, and using that model during the protocol’s run. In the light of this, it is evident that we need a method which the clients can run during training, without breaking the flow of training from the server’s point of view.
During training with SplitGuard, clients intermittently input batches with randomized labels, denoted fake batches. The main idea is that if the client model is learning the intended task, then the gradient values received from the server should be noticeably different for fake batches and regular batches. (We similarly use fake gradients and regular gradients to refer to the gradients resulting from fake and regular batches, respectively.)
The client model learning the intended task means that it is moving towards a relatively high-accuracy point in its parameter space. That same high-accuracy point becomes a low-accuracy point when the labels are randomized, so the model tries to move away from it, and the classification error increases. More specifically, we make the following two claims (experimentally validated in Section V-B):
If the client model is learning the intended task, then the angle between fake and regular gradients will be higher than the angle between two random subsets of regular gradients. (By the angle between two sets, we mean the angle between the sums of the vectors in those sets.)
If the client model is learning the intended task, then fake gradients will have a higher magnitude than regular gradients.
Table I: Notation
|Symbol|Description|
|---|---|
|$P_F$|Probability of sending a fake batch|
|$B_F$|Share of randomized labels in a fake batch|
|$N$|Batch index at which SplitGuard starts running|
|$F$|Set of fake gradients|
|$R_1$, $R_2$|Random, disjoint subsets of regular gradients|
|$\alpha$, $\beta$|Parameters of the SplitGuard score function|
|$L$|Number of classes|
|$A$|Model's classification accuracy|
|$A_F$|Expected classification accuracy for a fake batch|
At the core of SplitGuard, clients compute a value, denoted the SplitGuard score, based on the fake and regular gradients they have collected up to that point. This value’s history is then used to reach a decision on whether the server is launching an attack or not. We now describe this calculation process in more detail. Table I displays the notation we use from here on.
Starting with the $N$th batch during the first epoch of training, clients send, with probability $P_F$, fake batches with a $B_F$ share of the labels randomized. (This is equivalent to allocating a certain share of the training dataset for this purpose before training.) Upon calculating the gradient values for their first layer, clients append the fake gradients to the list $F$, and split the regular gradients randomly into the disjoint lists $R_1$ and $R_2$. To minimize the effect of fake batches on model performance, clients discard the parameter updates resulting from fake batches. Figure 2 displays a simplified overview of the protocol, and Algorithm 1 explains the modified training procedure in more detail. The MAKE_DECISION function contains the clients' decision-making logic and will be described later in Algorithm 3.
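The batch-routing logic of this modified training procedure can be sketched as follows. This is a hedged approximation, not Algorithm 1 itself. `step` is a hypothetical callback that runs the forward pass, receives the server's gradients, returns the client's first-layer gradient, and applies the parameter update only when `apply_update` is true. `randomize` stands for the label randomization described next. We also assume that regular gradients are recorded only from batch $N$ onward and that each one goes to $R_1$ or $R_2$ with equal probability:

```python
import random
from typing import Callable, List, Sequence, Tuple

Gradient = List[float]


def splitguard_epoch(
    batches: Sequence[Tuple[object, List[int]]],
    step: Callable[..., Gradient],
    randomize: Callable[[List[int]], List[int]],
    n_start: int,
    p_fake: float,
    seed: int = 0,
) -> Tuple[List[Gradient], List[Gradient], List[Gradient]]:
    """Route batches as fake or regular and collect F, R1, R2 (sketch)."""
    rng = random.Random(seed)
    fakes: List[Gradient] = []
    r1: List[Gradient] = []
    r2: List[Gradient] = []
    for i, (x, y) in enumerate(batches):
        if i >= n_start and rng.random() < p_fake:
            # Fake batch: keep the first-layer gradient but discard the update.
            fakes.append(step(x, randomize(y), apply_update=False))
            # A decision function would be consulted here after every fake batch.
        else:
            # Regular batch: train normally.
            grad = step(x, y, apply_update=True)
            if i >= n_start:
                # Assign each regular gradient to exactly one of R1, R2 (disjoint).
                (r1 if rng.random() < 0.5 else r2).append(grad)
    return fakes, r1, r2
```

Discarding the updates from fake batches is what keeps their effect on the client model small; the server still sees an ordinary stream of batches.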
To make sure that none of the randomized labels gets mapped to its original value, it is a good idea to add to each label a random integer between 1 and $L$ (exclusive), and compute the result modulo $L$, where $L$ is the number of classes.
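A minimal sketch of this randomization follows. The function name `randomize_labels` is hypothetical, and we assume that exactly `round(share * len(labels))` positions (a $B_F$ share) are chosen uniformly at random:

```python
import random
from typing import List, Optional, Sequence


def randomize_labels(labels: Sequence[int], share: float, num_classes: int,
                     rng: Optional[random.Random] = None) -> List[int]:
    """Randomize a `share` of the labels; every chosen label is guaranteed to change."""
    rng = rng or random.Random()
    out = list(labels)
    k = round(share * len(out))
    for i in rng.sample(range(len(out)), k):
        offset = rng.randrange(1, num_classes)    # 1 <= offset <= L - 1
        out[i] = (out[i] + offset) % num_classes  # never equals the original label
    return out
```

Because the offset is never a multiple of $L$, the shifted label always differs from the original, and it is uniform over the other $L - 1$ classes.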
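We should first define two quantities. For two sets of vectors $U$ and $V$, we define $d(U, V)$ as the absolute difference between the average magnitudes of the vectors in $U$ and $V$:
$$d(U, V) = \left| \frac{1}{|U|}\sum_{u \in U} \|u\| - \frac{1}{|V|}\sum_{v \in V} \|v\| \right|,$$
and $\theta(U, V)$ as the angle between the sums of vectors in the two sets $U$ and $V$:
$$\theta(U, V) = \arccos\left( \frac{\Sigma U \cdot \Sigma V}{\|\Sigma U\|\,\|\Sigma V\|} \right),$$
where $\Sigma U = \sum_{u \in U} u$ for a set of vectors $U$.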
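Both quantities can be computed directly from the collected gradients. Below is a minimal pure-Python sketch. It assumes gradients are flattened into equal-length lists of floats and, as an assumption on our part, reports angles in degrees:

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def norm(v: Vector) -> float:
    return math.sqrt(sum(x * x for x in v))


def vec_sum(vectors: Sequence[Vector]) -> List[float]:
    """Sigma U: element-wise sum of equal-length vectors."""
    return [sum(column) for column in zip(*vectors)]


def d(U: Sequence[Vector], V: Sequence[Vector]) -> float:
    """Absolute difference between the average magnitudes of the vectors in U and V."""
    return abs(sum(map(norm, U)) / len(U) - sum(map(norm, V)) / len(V))


def theta(U: Sequence[Vector], V: Sequence[Vector]) -> float:
    """Angle, in degrees, between Sigma U and Sigma V."""
    su, sv = vec_sum(U), vec_sum(V)
    cos = sum(a * b for a, b in zip(su, sv)) / (norm(su) * norm(sv))
    return math.degrees(math.acos(max(-1.0, min(1.0, cos))))  # clamp rounding error
```

Clamping the cosine to $[-1, 1]$ guards `math.acos` against tiny floating-point overshoots when the two sums are nearly parallel.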
Going back to the two claims, we can restate them using these quantities: (Claim 1) $\theta(F, R_1) > \theta(R_1, R_2)$, and (Claim 2) $d(F, R_1) > d(R_1, R_2)$.
If the model is learning the intended task, then it follows from the two claims that the product $d(F, R_1)\cdot\theta(F, R_1)$ will be greater than the product $d(R_1, R_2)\cdot\theta(R_1, R_2)$. If the model is learning some other task independent of the labels, then $F$, $R_1$, and $R_2$ will essentially be three random samples of the set of gradients obtained during training, and it will not be possible to consistently detect the same relationships among them.
We can now define the values clients compute to reach a decision. First, after each fake batch, the clients compute a value S (Equation 6) that compares the fake-gradient product $d(F, R_1)\cdot\theta(F, R_1)$ with the regular-gradient product $d(R_1, R_2)\cdot\theta(R_1, R_2)$. The numerator of S contains the useful information we want to extract, and we divide that result by a normalizing term that includes a small constant $\varepsilon$ to avoid division by zero. This division bounds S within a fixed interval, a feature that will shortly come in handy.
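To make the normalization idea concrete, here is a hedged sketch. Equation 6 itself is not reproduced in this text, so the normalization below is our own assumption and not necessarily the paper's formula: the difference of the two products divided by their sum plus $\varepsilon$. It keeps the value strictly inside $(-1, 1)$ and is positive when the fake-gradient product dominates. Pairing $F$ with $R_1$ (rather than $R_2$) is also an assumption. The helpers repeat the $d$ and $\theta$ definitions so that the block is self-contained:

```python
import math
from typing import Sequence

Vector = Sequence[float]


def norm(v: Vector) -> float:
    return math.sqrt(sum(x * x for x in v))


def d(U: Sequence[Vector], V: Sequence[Vector]) -> float:
    return abs(sum(map(norm, U)) / len(U) - sum(map(norm, V)) / len(V))


def theta(U: Sequence[Vector], V: Sequence[Vector]) -> float:
    su = [sum(c) for c in zip(*U)]
    sv = [sum(c) for c in zip(*V)]
    cos = sum(a * b for a, b in zip(su, sv)) / (norm(su) * norm(sv))
    return math.degrees(math.acos(max(-1.0, min(1.0, cos))))


def s_value(F, R1, R2, eps: float = 1e-8) -> float:
    """Hypothetical normalized score in (-1, 1); NOT necessarily the paper's Equation 6."""
    fake = d(F, R1) * theta(F, R1)        # fake vs. regular gradients
    regular = d(R1, R2) * theta(R1, R2)   # regular vs. regular gradients
    return (fake - regular) / (fake + regular + eps)
```

When fake gradients point away from regular ones and are larger, `s_value` approaches 1. When all three sets look alike, as expected under a label-independent attack, it hovers around 0.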
So far, the claims lead us to consider high S values as indicating an honest server, and low S values as indicating a malicious server. However, the S values obtained during honest training vary from one model/task to another, so we need to define the notions of high and low more precisely. For this purpose, we define a squashing function that maps the bounded range of S values to the interval $[0, 1]$, where high S values get mapped infinitesimally close to 1 while lower values get mapped to considerably lower values. (From here on, we refer to values very close to 1 as being equal to 1, since that is effectively the case when working with limited-precision floating-point numbers.) This allows the clients to choose a threshold, such as 0.9, to separate high and low values.
Our function of choice for the squashing function is the logistic sigmoid function. To give the clients some flexibility, we introduce two parameters, $\alpha$ and $\beta$, into this function; its output is the SplitGuard score.
The function fits naturally over the bounded range of S values, mapping the high end of the range to 1 and the low end to 0. The parameter $\alpha$ determines the range of values that get mapped very close to 1, while increasing the parameter $\beta$ punishes the values that are less than 1. We discuss the process of choosing these parameters in more depth in Section VI.
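The roles of $\alpha$ and $\beta$ can be illustrated with one hypothetical instantiation: a logistic sigmoid of $\alpha S$ raised to the power $\beta$. This form is an assumption made for illustration and is not necessarily the paper's exact function. It does show the described behavior: a larger $\alpha$ pushes more S values towards 1, and a larger $\beta$ pulls down every value below 1.

```python
import math


def squash(s: float, alpha: float = 5.0, beta: float = 2.0) -> float:
    """Hypothetical SplitGuard score: sigmoid(alpha * s) ** beta, which lies in (0, 1)."""
    return (1.0 / (1.0 + math.exp(-alpha * s))) ** beta
```

Under this assumed form with $\alpha = 5$ and $\beta = 2$, an S value of roughly 0.58 or more is needed to clear a threshold of 0.9.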
We need to answer four questions to claim that SplitGuard is an effective method:
How much does sending fake batches affect model performance? If the decrease is significant, then the harm might outweigh the benefit.
Do the underlying claims hold?
Can SplitGuard succeed in detecting FSHA, while not reporting an attack during honest training?
What can a typical adversary learn until detection?
In each of the following subsections, we answer one of these questions by conducting various experiments. For our experiments, we used the ResNet architecture, trained with the Adam optimizer, on the MNIST, Fashion-MNIST, and CIFAR10 datasets. We implemented our experiments in Python (v3.7) using the PyTorch library (v1.9). In all our experiments, we limit our scope to the first epoch of training. This is the least favorable time for detecting an attack, since the model initially behaves randomly, so our results represent a lower bound for results in later epochs.
Table II: Classification accuracy (%)
Table II displays the classification accuracy of the ResNet model on the test sets of our three benchmark datasets with different $B_F$ values, averaged over 10 runs. The client model consists of a single convolutional layer, and the rest of the model is computed by the server. This is the worst-case scenario for this purpose, since the part of the model that is being updated with fake batches is as large as possible. Also remember that a $B_F$ value of 1 does not mean that the clients always send fake labels; fake batches are still sent only with probability $P_F$.
The results demonstrate that even when limited to the first epoch, the model performs similarly when trained with and without SplitGuard. There is no noticeable and consistent decrease in performance for any of the datasets, even for high $B_F$ values such as 1.
Going back to the two claims, we now demonstrate that fake gradients make a larger angle with regular gradients than the angle between two subsets of regular gradients, and that fake gradients have a higher magnitude than regular gradients. Figures 3 and 4 display these values for each of our three datasets obtained during the first epoch of training with an honest server, averaged over 5 runs.
From Figure 3, it can be observed that $\theta(F, R_1)$ is consistently greater than $\theta(R_1, R_2)$ for each of our benchmark datasets. Note, however, that the difference is greater for MNIST (around 60) than for Fashion-MNIST (around 30) and CIFAR10 (around 10). Recalling from Table II that the model's performance after the first epoch of training is higher on MNIST than on the other datasets, it is not surprising that the difference between the angles is higher as well. As we will discuss later, SplitGuard is more effective as the model becomes more accurate.
Finally, Figure 4 displays a similar relation between the $d(F, R_1)$ and $d(R_1, R_2)$ values obtained during the first epoch of training. For each of our datasets, $d(F, R_1)$ values are consistently higher than $d(R_1, R_2)$ values, although the difference is smaller for CIFAR10 than for MNIST.
To recap, Figures 3 and 4 demonstrate that our claims are valid during the first epoch of training for our benchmark datasets. The decreasing difference as the models become less adept (going from MNIST to CIFAR10) implies that the protocol might need to be extended beyond the first epoch for more complex tasks.
With the claims validated, the question of actual effectiveness remains: how well does SplitGuard defend against FSHA?
To show that SplitGuard can effectively detect a SplitNN server launching FSHA, we ran the attack for each of our datasets. Figure 5 displays the SplitGuard scores obtained during the first epoch of training by the clients against an honest server and a FSHA attacker, averaged over 5 runs. The $P_F$ value is held fixed, and the $B_F$ value varies. (Note that the $B_F$ value does not affect the SplitGuard scores obtained against a FSHA server, since the client's loss function is independent of the labels.) We experimentally set the $\alpha$ and $\beta$ values to 5 and 2, respectively, as reasonable starting points, although we do not claim that they are optimal.
The results displayed in Figure 5 indicate that the SplitGuard scores are distinguishable enough to enable detection by the client. The SplitGuard scores obtained with an honest server are very close or equal to 1, while the scores obtained against a FSHA server remain clearly lower. Notice that higher $B_F$ values are, as expected, more effective. For example, it takes slightly more time for the scores to settle around 1 for Fashion-MNIST with a lower $B_F$ than with a $B_F$ of 1. The same holds for CIFAR10, although there it is evident that the $B_F$ value should be set higher.
To assess more rigorously how accurately SplitGuard detects FSHA, and likewise avoids reporting an attack during honest training, we define three candidate decision-making policies with different goals and test each one's effectiveness. A policy takes as input the list of SplitGuard scores obtained up to that point, and decides whether the server is launching a training-hijacking attack. We set a threshold of 0.9 for these example policies. While the clients can choose different thresholds (Section VI-B), the results in Figure 5 indicate that 0.9 is a sensible starting point. The three policies, also displayed in Algorithm 2, are defined as follows:
Fast: Fix an early batch index. After that index, report an attack if the last score obtained is less than the threshold. The goal of this policy is to detect an attack as fast as possible, without worrying too much about a high false positive rate.
Avg-$k$: Report an attack if the average of the last $k$ scores is less than the threshold. This policy represents a middle point between the Fast and the Voting policies.
Voting: Wait until a certain number of scores is obtained. Then divide the scores into a fixed number of groups, calculate each group's average, and report an attack if the majority of the group averages are less than the threshold. This policy aims for a high overall success rate (i.e., high true positive and low false positive rates), and can tolerate making decisions relatively late.
Note that these policies are not conclusive, and are provided as basic examples. More complex policies can be implemented to suit different settings. We will discuss the clients’ decision-making process in more detail in Section VI-B.
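The three policies can be sketched as simple functions over the score history. This is a hedged approximation of the descriptions above rather than Algorithm 2. The function names, the use of `None` to mean "not enough information yet", and the choice to group the most recent scores in Voting are our assumptions:

```python
from statistics import mean
from typing import Optional, Sequence

THRESHOLD = 0.9


def fast(scores: Sequence[float], batch_idx: int, start_batch: int,
         threshold: float = THRESHOLD) -> Optional[bool]:
    """Fast: after `start_batch`, report an attack if the latest score is low."""
    if batch_idx < start_batch or not scores:
        return None  # undecided
    return scores[-1] < threshold


def avg_k(scores: Sequence[float], k: int,
          threshold: float = THRESHOLD) -> Optional[bool]:
    """Avg-k: report an attack if the mean of the last k scores is low."""
    if len(scores) < k:
        return None
    return mean(scores[-k:]) < threshold


def voting(scores: Sequence[float], group_size: int, group_count: int,
           threshold: float = THRESHOLD) -> Optional[bool]:
    """Voting: majority vote over the means of consecutive groups of recent scores."""
    needed = group_size * group_count
    if len(scores) < needed:
        return None
    recent = scores[-needed:]
    votes = sum(
        mean(recent[i * group_size:(i + 1) * group_size]) < threshold
        for i in range(group_count)
    )
    return votes > group_count / 2
```

With the settings used in Table III, `avg_k(scores, 50)` and `voting(scores, 5, 10)` both wait for 50 scores before reaching a decision.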
Table III displays the detection statistics for each of these policies, obtained over 100 runs of the first epoch of training with a FSHA attacker and an honest server, with a $B_F$ of 1 and a $P_F$ of 0.1. For the Avg-$k$ policy, we use $k$ values of 10, 20, and 50, corresponding to roughly 100, 200, and 500 batches with a $P_F$ of 0.1; this ensures that the policy can run within the first training epoch. (With a batch size of 64, one epoch is equal to 938 batches for MNIST and F-MNIST, and 782 for CIFAR10.) For the Voting policy, we set the group size to 5 and the group count to 10, again corresponding to around 500 batches with a $P_F$ of 0.1. Finally, we set $N$, the index at which SplitGuard starts running, to 20 for MNIST and F-MNIST, and 50 for CIFAR10.
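The arithmetic behind these settings can be checked quickly. This minimal sketch assumes the standard training-set sizes (60,000 images for MNIST and Fashion-MNIST, 50,000 for CIFAR10). It treats the number of SplitGuard scores as the expected number of fake batches after batch $N$; the helper names are hypothetical:

```python
import math


def batches_per_epoch(n_train: int, batch_size: int = 64) -> int:
    """Number of batches in one epoch; the last, partial batch counts as one."""
    return math.ceil(n_train / batch_size)


def expected_scores(n_batches: int, n_start: int, p_fake: float) -> float:
    """Expected number of fake batches (hence SplitGuard scores) from batch n_start on."""
    return (n_batches - n_start) * p_fake
```

For example, `expected_scores(938, 20, 0.1)` is about 91.8 and `expected_scores(782, 50, 0.1)` is about 73.2. Both are comfortably above the 50 scores that Avg-50 and the Voting policy need, and 50 scores correspond to roughly 50 / 0.1 = 500 batches.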
A significant result is that all the strategies achieve a perfect true positive rate (i.e., they successfully detect all runs of FSHA). As expected, the Fast strategy achieves the fastest detection times, as shown in Table III, detecting all instances of the attack in fewer than a hundred training batches.
Another important observation is that the false positive rates increase as the model's performance decreases, moving from MNIST to F-MNIST and then CIFAR10. This means that more training time should be taken to achieve higher success rates in more complex tasks. This is not a troubling scenario since, as we will shortly observe, a model with low performance also makes the attack less effective. Nevertheless, the Voting policy achieves a false positive rate of 0 for (F-)MNIST and 0.02 for CIFAR10, indicating that despite the relatively high false positive rates of the Avg-$k$ policies, better detection performance in less time is achievable through smarter policies.
We now analyze what a FSHA adversary can obtain until the detection batch indices displayed in Table III. Figure 6 displays the results obtained by the attacker after the batch indices corresponding to the detection times of the given policies. Note that all these batch indices fall within the first training epoch.
For the Fast policy, the attacker obtains little more than random noise. This means that if a high false positive rate can be tolerated (e.g., the privacy of the data is highly critical and the server is distrusted), this policy can be applied to prevent any data leakage.
Unsurprisingly, the attack results get more accurate as the attacker is given more time. Nevertheless, especially for the more complex CIFAR10 task, the results obtained by the attacker against the Voting policy do not contain the distinguishing features of the original images. This highlights the effectiveness of the Voting policy, preventing significant information leakage with a relatively low false positive rate of 0.02. We would like to note once again that the policies described above are rather simplistic, and do not use the clients’ full power, as will be discussed in Section VI-B.
Finally, the CIFAR10 results also give credibility to our previous statement that giving more time to the attacker for a more complex task should not be a cause of worry. After the same number of batches, the attacker’s results for MNIST and Fashion-MNIST are more accurate than the CIFAR10 results.
In this section, we answer the following questions:
What is the computational complexity of running SplitGuard (for clients)?
How can the clients make a decision on whether the server is honest or not?
Can the attacker detect SplitGuard? What happens if it can?
What effect do the parameters have on the system?
Can SplitGuard generalize to different scenarios?
What are some concrete use cases for SplitGuard?
We now argue that SplitGuard does not incur a significant computational cost regarding time or space. Since SplitNN clients are already assumed to be able to run back-propagation on a few DNN layers, calculating the S value described in Equation 6 is a simple task.
Space-wise, although it might seem that storing the gradient vectors for potentially multiple epochs requires a significant amount of space, the clients in fact do not have to store all the gradient vectors. For each of the sets $F$, $R_1$, and $R_2$, the clients have to maintain two quantities: the sum of all vectors in the set, and the average magnitude of the vectors in the set; the first has the dimensions of a single gradient vector, and the second is a scalar. More importantly, both of these quantities can be maintained in a running manner. This keeps the total space required by SplitGuard at $O(1)$ with respect to training time, equivalent to the space needed for three scalar values and three gradient vectors. For reference, the space required to store a single gradient vector in our experiments was 2.304 KB. Since the space requirement is independent of the total number of batches, it is possible to run SplitGuard for arbitrarily long training processes.
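A minimal sketch of such a running summary follows. It assumes flattened gradients, and the class `RunningSet` is hypothetical. Besides the running sum and the average magnitude, it keeps an integer count, which the incremental mean needs. The $d$ and $\theta$ quantities can then be computed from the summaries alone:

```python
import math
from typing import List, Sequence


class RunningSet:
    """Constant-space summary of a set of gradient vectors."""

    def __init__(self, dim: int) -> None:
        self.total: List[float] = [0.0] * dim  # running sum, i.e. Sigma U
        self.mean_norm = 0.0                   # running average magnitude
        self.count = 0

    def add(self, grad: Sequence[float]) -> None:
        self.total = [t + g for t, g in zip(self.total, grad)]
        self.count += 1
        norm = math.sqrt(sum(g * g for g in grad))
        self.mean_norm += (norm - self.mean_norm) / self.count  # incremental mean


def d(a: RunningSet, b: RunningSet) -> float:
    return abs(a.mean_norm - b.mean_norm)


def theta(a: RunningSet, b: RunningSet) -> float:
    """Angle in degrees between the two running sums."""
    dot = sum(x * y for x, y in zip(a.total, b.total))
    na = math.sqrt(sum(x * x for x in a.total))
    nb = math.sqrt(sum(y * y for y in b.total))
    return math.degrees(math.acos(max(-1.0, min(1.0, dot / (na * nb)))))
```

Each `add` call discards the incoming gradient after folding it into the summary, so memory stays fixed no matter how many batches are processed.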
We have described some makeshift decision-making policies in Section V-C, and in this subsection we discuss the clients’ decision-making process in more depth, without focusing on a specific policy.
After each fake batch, clients can make a decision on whether the server is launching an attack or not. The main decision procedure is as follows:

1. Is the SG value high or low?
   a. If high, there are no problems. Keep training.
   b. If low, there are two possible explanations:
      i. The model has not learned enough yet. Keep going, potentially making changes.
      ii. The server is launching an attack. Halt training.

Going back to the policies we have described in Section V-C, it can be seen that they did not consider the first explanation (1.b.i) of low scores, namely the model not having learned enough. As we will see, taking that into consideration could help reduce the false positive rates.
The outline contains two branching points: separating high and low scores, and explaining low scores.
Separating High and Low Scores.
The process of separating high and low SplitGuard scores consists of two steps: setting the hyperparameters of the squashing function, and deciding on a threshold value in the interval. We consider two scenarios: the clients know or do not know the server model architecture.
If the clients know the architecture, then they can train the entire model using all or part of their local data, and gain a prior understanding of what S values (Equation 6) to expect during honest training. The parameters $\alpha$ and $\beta$ can then be adjusted to map these values very close to 1. In this scenario, since the clients' confidence in the accuracy of the method is expected to be higher, a relatively high threshold can be set.
If the clients do not know the model architecture, then they should set the parameters $\alpha$ and $\beta$ manually. Nevertheless, the fact that all S values lie within a bounded interval makes the clients' job easier. It is unreasonable to set extremely high $\alpha$ or $\beta$ values, since they will cause the squashing function to make sudden jumps, or map no value close to 1. As our experiments also demonstrate, smaller values such as 5 and 2 are reasonable starting points.
Finally, note that the clients do not have to decide based on a single SplitGuard score. They can consider the entire history of the score, as depicted in Figure 5 and done in the Avg-k and Voting policies. For example, the score making a sudden jump to 1 and shortly going down to 0.5 does not imply honest training; similarly, the score making a sudden jump down to 0.5 after consistently remaining close to 1 does not strictly imply training-hijacking.
Explaining Low Scores. When a client decides that the SplitGuard score is low, it should choose between two possible explanations: either the model has not learned enough yet, or the server is launching a training-hijacking attack.
Informally, a low score indicates that fake gradients are not that different from regular gradients; the model behaves similarly when given fake batches and regular batches. In the domain of classification, behaving similarly is equivalent to having a similar classification accuracy. Then, the explanation that the model has not learned enough yet is more likely if the expected classification accuracy for a fake batch is close to the actual (expected) prediction accuracy. If these values are different but the SplitGuard score is still low, then the server is very likely launching an attack.
We can formulate the expected accuracy for a fake batch. Given the total number of labels and the overall model's classification accuracy, the expected classification accuracy for a fake batch in which a given share of the labels is randomized can be expressed in closed form (Equation 7).
Figure 7 explains this equation visually.
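As a hedged illustration of this kind of computation, the sketch below derives the expected fake-batch accuracy under two explicit assumptions about how labels are randomized: either a randomized label is drawn uniformly from all L labels (so it may coincide with the true one), or it is drawn uniformly from the L−1 wrong labels. In both cases an unrandomized example is correct with probability A, the model's accuracy. The function name and the `exclude_true_label` switch are hypothetical, and this is not necessarily identical to Equation 7.

```python
def expected_fake_accuracy(acc, num_labels, share, exclude_true_label=False):
    """Hypothetical helper: expected accuracy on a fake batch in which a
    fraction `share` of the labels has been randomized.

    acc:       model accuracy A on regular examples, in [0, 1]
    num_labels: total number of labels L (at least 2)
    share:     fraction of labels randomized in the fake batch, in [0, 1]
    """
    if not (0.0 <= acc <= 1.0 and 0.0 <= share <= 1.0):
        raise ValueError("acc and share must lie in [0, 1]")
    if num_labels < 2:
        raise ValueError("need at least two labels")
    if exclude_true_label:
        # Random label drawn from the L-1 wrong labels: a correct prediction
        # never matches it, a wrong one matches with probability 1/(L-1).
        random_hit = (1.0 - acc) / (num_labels - 1)
    else:
        # Random label drawn uniformly from all L labels, independently of
        # the input: any prediction matches it with probability 1/L.
        random_hit = 1.0 / num_labels
    # Unrandomized examples keep accuracy acc; randomized ones score random_hit.
    return (1.0 - share) * acc + share * random_hit
```

For example, with A = 0.9 and L = 10, randomizing half of the labels gives 0.5 · 0.9 + 0.5 · 0.1 = 0.5 under the first assumption. In both variants the coefficient of `share` is `random_hit - acc`, which is non-positive whenever A ≥ 1/L, so increasing the share never raises the expected fake-batch accuracy of a model that beats random guessing.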
If the model terminates on the client side (as in Figure 0(b)), then the clients already know the exact accuracy value. If that is not the case but the clients know the model architecture on the server side, then they can train the model using their local data, and obtain an estimate of the expected classification accuracy of the actual model during the first epoch. If even that is not possible, then in the worst case the clients can train a linear classifier appended to their model to obtain a lower bound on the original model's accuracy. (Footnote 7: A related, interesting study concludes that what a neural network learns during its initial epoch of training can be explained by a linear classifier, in the sense that if we know the linear model's output, then knowing the main model's output provides almost no benefit in predicting the label. Note, however, that this holds only for the optimal linear classifier, not for an arbitrary one.)
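The worst-case option above can be sketched as a linear probe: a softmax-regression classifier trained on the frozen outputs of the client's part of the model. This is a minimal NumPy sketch assuming the client can collect those outputs as a feature matrix for some labeled local data. The helper `linear_probe_accuracy`, its learning rate, and its epoch count are hypothetical choices. The probe is trained separately rather than jointly with the client model.

```python
import numpy as np


def linear_probe_accuracy(train_x, train_y, test_x, test_y, num_classes,
                          lr=0.1, epochs=500, seed=0):
    """Hypothetical helper: fit a softmax-regression probe on frozen
    client-side features and return its test accuracy, as a rough
    lower-bound estimate of what the full model can reach."""
    train_x = np.asarray(train_x, dtype=float)
    test_x = np.asarray(test_x, dtype=float)
    train_y = np.asarray(train_y)
    test_y = np.asarray(test_y)
    # Standardize with training statistics to keep gradient descent stable.
    mu = train_x.mean(axis=0)
    sd = train_x.std(axis=0) + 1e-8
    train_x = (train_x - mu) / sd
    test_x = (test_x - mu) / sd

    rng = np.random.default_rng(seed)
    n, d = train_x.shape
    w = 0.01 * rng.standard_normal((d, num_classes))
    b = np.zeros(num_classes)
    onehot = np.eye(num_classes)[train_y]
    for _ in range(epochs):
        logits = train_x @ w + b
        logits -= logits.max(axis=1, keepdims=True)  # numerical stability
        probs = np.exp(logits)
        probs /= probs.sum(axis=1, keepdims=True)
        grad = (probs - onehot) / n  # gradient of mean cross-entropy w.r.t. logits
        w -= lr * (train_x.T @ grad)
        b -= lr * grad.sum(axis=0)
    preds = np.argmax(test_x @ w + b, axis=1)
    return float(np.mean(preds == test_y))
```

In practice the features would be the client model's activations on held-out local data, and the learning rate may need tuning for high-dimensional features.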
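Formalizing this discussion, for SplitGuard to be effective, the expected classification accuracy on fake batches must be noticeably lower than the model's accuracy on regular batches. If the two are close, then the clients' choice of the share of randomized labels is not right, and they should increase it. Note that the expected fake-batch accuracy is a linear function of this share, with a coefficient that is non-positive whenever the model performs at least as well as random guessing. Thus, the expected fake-batch accuracy is indeed a monotonic function of the share, and increasing the share either keeps it constant or decreases it. Then, when the clients decide that the SG value is low and that the two accuracies are close, the best course of action is to increase the share of randomized labels. If the share is already 1, then clients should wait until the model becomes sufficiently accurate for a completely randomized batch to make a difference. As discussed previously, this is not a worrisome scenario, since the attack's effectiveness also relies on the model's adeptness.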
Finally, an alternative course of action is to increase the number of initial batches to ignore, discarding the initial group of gradients. Since the models behave randomly in the beginning, ignoring more initial batches reduces the noise, and can help distinguish an honest server from a malicious one. Also note that this change is reversible, provided that the clients store the gradient values.
With these discussions, we can finalize the clients’ decision-making process as the function MAKE_DECISION, displayed in Algorithm 3.
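The low-score reasoning above can be condensed into a small decision helper. This is a hypothetical sketch of that reasoning, not a transcription of Algorithm 3. The name `make_decision`, its return labels, and the `min_gap` parameter (a client-chosen minimum accuracy gap for fake batches to count as distinguishable) are our own assumptions. The alternative of ignoring more initial batches is not modeled.

```python
def make_decision(sg_score, threshold, model_acc, fake_acc, share, min_gap):
    """Hypothetical decision helper following the low-score reasoning.

    sg_score:  current (or history-averaged) SplitGuard score in [0, 1]
    threshold: client-chosen decision threshold
    model_acc: (estimated) model accuracy on regular batches
    fake_acc:  expected model accuracy on fake batches
    share:     current share of randomized labels in fake batches
    min_gap:   assumed minimum gap for fake batches to be distinguishable
    """
    if sg_score >= threshold:
        return "continue"  # score is consistent with honest training
    if model_acc - fake_acc >= min_gap:
        # Fake batches should look different, yet the score is low:
        # the server is very likely hijacking the training.
        return "attack"
    if share < 1.0:
        return "increase_share"  # make fake batches more distinct
    return "wait"  # already fully randomized; let the model learn more
```

Keeping the gap check separate from the threshold check mirrors the two explanations for a low score: an insufficiently trained model versus a malicious server.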
An attacker can in turn try to detect that a client is running SplitGuard. It can then try to circumvent SplitGuard by using a legitimate surrogate model as described before.
If the server controls the model's output (Figure 0(a)), then it can detect whether the classification error of a batch is significantly higher than that of the others. Since SplitGuard is a potential, though not the only, explanation for such behavior, this presents an opportunity for an attacker to detect it. However, the model behaving significantly differently on fake and regular batches also implies that the model is at a stage at which SplitGuard is effective. This leads to an interesting scenario: since the effectiveness of both the attack and SplitGuard depends on the model having learned enough, it seems as if the attack cannot be detected without the attacker detecting SplitGuard, and vice versa.
We argue that this is not the case, because the clients are in charge of setting the share of randomized labels. For example, on the MNIST dataset, given the classification accuracy the model obtains after the first epoch of training, a share below 1 results in an expected fake-batch accuracy (Equation 7) that is not conspicuously low. On the other hand, the SplitGuard scores displayed in Figure 5 being very close to 1 implies that an attack can still be detected with such a value. Thus, clients can make it difficult for an attacker to detect SplitGuard by setting this value deliberately, rather than blindly setting it to 1 for better effectiveness.
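Choosing the share deliberately can be made concrete under one simple assumption: if randomized labels are re-drawn uniformly from all L labels, then randomizing a share s of a batch lowers the expected accuracy from A to (1 − s)A + s/L, a drop of exactly s(A − 1/L). A client could then pick the smallest share that still produces a chosen target drop. The helper below and its `target_gap` parameter are hypothetical, and the target is a client-side choice rather than a value from our experiments.

```python
def smallest_share_for_gap(acc, num_labels, target_gap):
    """Hypothetical helper: smallest randomized-label share whose expected
    accuracy drop reaches target_gap, assuming labels are re-drawn
    uniformly from all num_labels classes (drop = share * (acc - 1/L))."""
    headroom = acc - 1.0 / num_labels
    if headroom <= 0:
        return None  # model no better than chance: no share creates a drop
    share = target_gap / headroom
    # None means even a fully randomized batch cannot reach the target.
    return share if share <= 1.0 else None
```

For instance, with A = 0.9 and L = 10, a target drop of 0.4 needs a share of only 0.4 / 0.8 = 0.5, leaving fake batches less conspicuous than full randomization would.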
Finally, we once again strongly recommend that a secure SplitNN setup follow the three-part setup shown in Figure 0(b), so that the clients do not need to share their labels with the server. This way, an attacker would not be able to see the accuracy of the model, and it would become significantly harder for it to detect SplitGuard.
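We have touched upon how the clients can decide on several parameter values, but for completeness we need to clarify the effects of the main parameters: the probability of sending a fake batch, the share of randomized labels in each batch, and the number of initial batches to ignore. Each parameter involves a different trade-off: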
Probability of sending a fake batch.
(+) Higher values mean more fake batches, and thus a more representative sample of fake gradient values, increasing the effectiveness of the method.
(−) Higher values can also degrade model performance, since the server model will be learning random labels for a larger number of examples, and a larger share of the potentially scarce dataset will be allocated to SplitGuard.
Share of randomized labels in each batch.
(+) More random labels in a batch mean that fake batches and regular batches behave even more differently, and the method becomes more effective.
(−) Depending on the model's training performance, batches with entirely random labels can be detected by the server. One way to overcome this difficulty is to perform the loss computation on the client side.
Number of initial batches to ignore.
(+) A smaller value means that the server's malicious behavior can be detected earlier, giving it less time to attack.
(−) Since a model behaves randomly at the beginning of training, the initial batches are of little value for our purposes. Computing SG scores for later batches makes it easier to distinguish honest behavior, but in return gives the attacker more time.
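The three parameters above come together when the client assembles each batch. The sketch below shows one way a client could do this. The helper names `make_fake_labels` and `prepare_batch` are hypothetical, and re-drawing randomized labels uniformly from all classes (so a randomized label may coincide with the original one) is an assumption.

```python
import random


def make_fake_labels(labels, num_classes, share, rng):
    """Hypothetical helper: copy `labels` and re-draw round(share * n)
    randomly chosen positions uniformly from all classes."""
    fake = list(labels)
    k = round(share * len(fake))
    for i in rng.sample(range(len(fake)), k):
        fake[i] = rng.randrange(num_classes)
    return fake


def prepare_batch(batch_idx, labels, num_classes, p_fake, share,
                  ignore_first, rng):
    """Hypothetical helper: decide whether this batch is fake (with
    probability p_fake), randomize its labels if so, and report whether
    its gradients should count toward the SplitGuard score (only batches
    after the first `ignore_first` ones do)."""
    is_fake = rng.random() < p_fake
    sent = make_fake_labels(labels, num_classes, share, rng) if is_fake else list(labels)
    use_for_score = batch_idx >= ignore_first
    return sent, is_fake, use_for_score
```

The client would keep the returned flags alongside the gradients it receives, so that fake and regular gradients can be separated later, and so that the number of ignored batches can be changed retroactively from the stored values.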
In the form we have discussed so far, a question might arise regarding SplitGuard's effectiveness in different scenarios. We argue, however, that since the claims underlying SplitGuard apply to any kind of neural network learning on any kind of data, SplitGuard generalizes to different data modalities and more complex architectures. The only caveat, as discussed earlier, is that learning on a more complex dataset or with a more complex architecture would require more time for SplitGuard to become effective.
Another direction of generalization is towards different attacks. Although there are no training-hijacking attacks other than FSHA against which we can test SplitGuard, we claim that SplitGuard can generalize to future attacks as well. After all, SplitGuard relies only on the assumption that randomizing the labels affects an honest model more than it affects a malicious model. Thus, to go undetected by SplitGuard, an attack should either involve learning significant information about the original task, which would likely reduce the attack’s effectiveness, or craft a different loss function for each label, which could easily be prevented by not sharing the labels with the server (Figure 0(b)).
Finally, SplitGuard also generalizes to multi-client SplitNN settings. Each client can independently run SplitGuard with its own choice of parameters, and would then be making a decision regarding its own training process. Alternatively, if the clients trust each other, they can choose one client to run SplitGuard in order to minimize the resulting performance loss, or they can combine their collected gradient values and reach a collective decision. (Footnote 8: This is similar to the Voting policy described earlier, where the separation of scores into groups follows naturally from their distributions among the clients.) The latter scenario would be equivalent to a single client training with the aggregated data of all the clients.
We now describe three potential real-world use cases for SplitGuard, modeling clients with different capabilities at each scenario.
Powerful Clients. A group of healthcare providers decide to train a DNN using their aggregate data while maintaining data privacy. They decide on a training setup and establish a central server. (Footnote 9: Alternatively, members of the group can take turns acting as the SplitNN server in a P2P manner.) Each client knows the model architecture and the hyperparameters, and preferably has access to the model's output as well (no label-sharing). The clients can train models using their local data to determine the squashing-function parameters. Each client can then run SplitGuard during its training turns and see whether it is being attacked. In this scenario the clients are as powerful as possible, so it represents the optimal scenario for running SplitGuard.
Intermediate Clients. The SplitNN server is a researcher attempting to perform privacy-preserving machine learning on the private dataset of some data-holder (the client). The researcher designs the training procedure, but the data-holder actively takes part in the protocol, and thus has tight control over how its data is organized. The client cannot train a local model since it does not know the entire architecture, and should set the squashing-function parameters manually. Nevertheless, it can easily run SplitGuard by modifying the training data used in the protocol.
Weak Clients. An application developer is the SplitNN server, and the users' mobile devices are the clients with private data. The clients do not know the model architecture and cannot manipulate how their data is shared with the server. The application developer is in control of the entire process from design to execution. In this scenario, SplitGuard would have to be implemented at a lower level, such as in the ML libraries that the mobile OS supports. However, even then, the application developer can implement a machine learning pipeline from scratch without relying on any libraries. This is therefore not an optimal scenario for running SplitGuard: strict regulations, as well as gatekeeping by the OS provider (e.g., mandating that machine learning code must use one of the specified libraries), would be needed before SplitGuard could effectively be implemented for such clients.
We outline three possible avenues of future work related to SplitGuard: providing practical implementations, improving its robustness against detection by the attacker, and developing potentially undetectable attacks.
Although we provide a proof-of-concept implementation of SplitGuard, it should ideally be supported directly by privacy-preserving machine learning libraries such as PySyft. That way, SplitGuard could be seamlessly integrated into the client-side processes of split learning pipelines.
As we have explained in Section VI-C, SplitGuard can potentially, although it is unlikely, be detected by the attacker, who can then start sending fake gradients from its legitimate surrogate model and regular gradients from its malicious model. This could again cause a significant difference between the fake and regular gradients, and result in a high SplitGuard score. However, a potential weakness of this approach is that the fake gradients now result from two different models with different objectives. Suppose the attacker detects SplitGuard at the 200th batch and starts using its legitimate model. Then the fake gradients within the first 200 batches will be computed using the malicious model, and those after the 200th batch will be computed using the legitimate model. Clients can potentially detect this switch in models and gain the upper hand. This is another point for which future improvement might be possible.
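One simple way a client might look for such a switch is a mean-shift change-point search over its stored fake gradients: pick the split index that maximizes the distance between the mean fake gradient before and after it. This is a hypothetical sketch of a standard technique, not part of SplitGuard. The name `most_likely_switch` and the `min_segment` parameter are illustrative, and each fake gradient is assumed to be flattened into a vector.

```python
import numpy as np


def most_likely_switch(fake_grads, min_segment=5):
    """Hypothetical helper: return (index, distance) of the split that
    maximizes the Euclidean distance between the mean fake gradient
    before and after it, or (None, 0.0) if the sequence is too short.

    fake_grads: array-like of shape (T, d), one flattened gradient per row.
    """
    g = np.asarray(fake_grads, dtype=float)
    n = len(g)
    best_idx, best_dist = None, 0.0
    # Only consider splits leaving at least min_segment rows on each side.
    for t in range(min_segment, n - min_segment + 1):
        dist = float(np.linalg.norm(g[:t].mean(axis=0) - g[t:].mean(axis=0)))
        if dist > best_dist:
            best_idx, best_dist = t, dist
    return best_idx, best_dist
```

A large maximum distance would suggest that the fake gradients come from two different sources. Deciding what counts as "large" (for example, by comparing against the same statistic computed on regular gradients) is left open here.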
Finally, turning the tables, it might be possible to modify the existing attacks, or propose novel attacks to produce high SplitGuard scores, very likely at the cost of effectiveness. This represents another line of future work concerning SplitGuard.
In this paper, we presented SplitGuard, a method for SplitNN clients to detect whether they are being targeted by a training-hijacking attack. We described the theoretical foundations underlying SplitGuard, experimentally evaluated its effectiveness, and discussed in depth many issues related to its use. We conclude that when used appropriately, and in a secure setting without label-sharing, a client running SplitGuard can successfully detect training-hijacking attacks and leave the attacker empty-handed.