In this post, I will leave some notes about the Verifiable Federated Learning CLI demo I have prototyped.
In a nutshell, this system allows mutually distrusting parties (say, nodes in a decentralized network like Ethereum) to privately contribute their data and collaboratively train an AI model with public verifiability.
In a separate post, I will share more about my vision for this technology, along with some fitting use cases I have in mind.
Table of Contents
Architecture Overview
1.0 Protocol steps
1.1 Local training in clients
1.2 Masking models in clients
1.3 Aggregation in server
1.4 Notes on fixed-point arithmetic
Future Research Direction
2.1 Training dataset validation
2.2 ZK circuit for more complex training algorithm… or fine-tuning?
2.3 Storing local model updates offchain
2.4 Batched/Packed secret sharing for masking models
2.5 Client dropout tolerance + real-time join
2.6 Reward system
Architecture Overview
In my CLI demo, 4 parties are simulated:
- 3 distributed clients that locally train a model on their own raw data
- 1 server that aggregates the local models submitted by the clients
I did not assume a direct P2P connection between them, so they communicate via the blockchain.
Protocol steps
- Clients locally train a model on their data using an agreed-upon training algorithm. While training, they generate a proof $\pi_{train}$ for the correct execution of the training algorithm on their secret inputs.
- Clients generate their own public and private key pair $(sk_n, pk_n)$. They submit $pk_n$ along with the proof $\pi_{train}$ to the blockchain, and after the proof verifies, their public keys are registered on chain, which signals their participation in this collaborative training system.
- Clients generate pair-wise shared masks $m_{right}$ and $m_{left}$ with their neighboring clients `neighbor_right` and `neighbor_left`, using their own private key $sk_n$ and the neighbors' public keys $pk_{n+1}$ and $pk_{n-1}$, respectively (I explain the algorithm in more detail in section 2. Masking models in clients). After masking their model parameters (weights and bias; I'll just call them a "model" from now on), they publish the masked model $M_n$ on chain, along with a ZKP $\pi_{mask}$ which proves that they executed the training and masking process correctly.
- After the masking proof $\pi_{mask}$ verifies, the server fetches the masked models and aggregates them. During the aggregation, the server generates a proof $\pi_{agg}$ for the correct execution of the aggregation, and after $\pi_{agg}$ verifies, the server registers the aggregated global model $M_g$ on chain.
- Clients can fetch the global model $M_g$ from the blockchain.
```mermaid
sequenceDiagram
    participant Client_n
    participant Blockchain
    participant Server
    Client_n-->>Client_n: Train local model R_n, Generate training proof π_train_n
    Client_n->>Blockchain: Submit (π_train_n + public key pk_n)
    Blockchain-->>Blockchain: if π_train_n verified, then pk_n registered
    Client_n->>Blockchain: Fetch pk_{n+1} (right neighbor) and pk_{n-1} (left neighbor)
    Client_n-->>Client_n: Locally compute shared masks m_right_n=sk_n*pk_{n+1}, m_left_n=sk_n*pk_{n-1},<br>Mask the model: R_n + m_right_n - m_left_n, Generate masking proof π_mask_n
    Client_n->>Blockchain: Submit masked model M_n + proof π_mask_n
    Blockchain-->>Blockchain: if π_mask_n verified, then M_n registered
    Server->>Blockchain: Fetch masked models M_n for all n
    Server-->>Server: Aggregate local models,<br>Generate aggregation proof π_agg
    Server->>Blockchain: Submit global model M_g + proof π_agg
    Blockchain-->>Blockchain: if π_agg verified, then M_g registered
    Client_n->>Blockchain: Fetch global model M_g
```
1. Local training in clients
There are three clients (client1-3), and each client locally trains a model on their raw data (for example, client1 has input data that looks like this) inside ZK using a logistic regression algorithm. The Noir circuit for logistic regression was implemented by Hashcloak for their noir-mpc-ml project (their project report is here), and I imported this circuit for the local training. Their approach, co-SNARK, can be taken as an alternative way to achieve the same goal as mine, but since my approach offloads the training process to the clients and does not require running it inside MPC, it is more efficient. (It just runs the training algorithm inside ZK in order to provide public verifiability.)
For example, Hashcloak has obtained the results below (with 3 MPC nodes):
(reference: Benchmarks)
Whereas my system obtained the following (with 3 clients and 1 server, all simulated on an Apple M1 with 16GB RAM):
| Epochs | Train samples | Training time [sec] |
|---|---|---|
| 10 | 30 | 31 |
| 10 | 50 | 44 |
| 20 | 30 | 55 |
| 20 | 50 | 96 |
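To make the training step concrete, here is a plain-Python sketch of the kind of multinomial logistic regression training each client runs locally (gradient descent over 4 features and 3 classes, matching the model layout shown later). This is only an illustration of the algorithm; the actual demo executes the training inside Hashcloak's Noir circuit with fixed-point arithmetic, and its implementation details (e.g. the softmax/sigmoid approximation) may differ.

```python
# Illustrative multinomial logistic regression trained by per-sample gradient
# descent. NOT the Noir circuit used in the demo; plain floats, no fixed point.
import math, random

def softmax(zs):
    m = max(zs)
    exps = [math.exp(z - m) for z in zs]
    s = sum(exps)
    return [e / s for e in exps]

def train(xs, ys, num_classes, epochs=10, lr=0.1):
    """xs: list of feature vectors, ys: class labels in range(num_classes)."""
    dim = len(xs[0])
    w = [[0.0] * dim for _ in range(num_classes)]
    b = [0.0] * num_classes
    for _ in range(epochs):
        for x, y in zip(xs, ys):
            logits = [sum(wi * xi for wi, xi in zip(w[c], x)) + b[c]
                      for c in range(num_classes)]
            probs = softmax(logits)
            for c in range(num_classes):
                grad = probs[c] - (1.0 if c == y else 0.0)
                for i in range(dim):
                    w[c][i] -= lr * grad * x[i]
                b[c] -= lr * grad
    return w, b

# e.g. 30 samples with 4 features and 3 classes (dummy data)
random.seed(0)
xs = [[random.random() for _ in range(4)] for _ in range(30)]
ys = [random.randrange(3) for _ in range(30)]
weights, bias = train(xs, ys, num_classes=3)
```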
2. Masking models in clients
This is the cryptographically interesting part. After training a model, clients submit the model so that the server can aggregate them into a global model. However, they cannot just publish the raw model, since there is a multitude of attacks an adversary can mount to reconstruct the training data just by looking at the weights (keywords: "Gradient Inversion Attack", "Membership Inference Attack", "Property Inference Attack").
So how can you mask your model in such a way that the server can compute an aggregation of the raw models without learning the individual values?
One-time pad based on Diffie-Hellman Key Exchange
I employed a technique from the original federated learning paper, which describes how clients send masked models to a server, and how the server aggregates them into a global model with the masks being removed at the end.
In summary, clients generate pair-wise masks and add/subtract them from their raw models so that the masks cancel each other out when the server aggregates the models.
Let's say we place the clients in a ring topology as below and label the pair-wise noise/mask between paired clients as $m_{\text{start-node}, \text{end-node}}$ (labeling the nodes counterclockwise in this case).
- Between `client1` and `client2`, they generate a shared mask $m_{1,2}$
- Between `client2` and `client3`, they generate a shared mask $m_{2,3}$
- Between `client3` and `client1`, they generate a shared mask $m_{3,1}$
Now imagine that each client adds the mask they generated with their right neighbor and subtracts the mask they generated with their left neighbor.
For each client $n$, the masked model $M_n$ is computed from the raw model $R_n$ as $M_n = R_n + m_{right} - m_{left}$, where $m_{right}$ is the mask shared with the right neighbor and $m_{left}$ is the mask shared with the left neighbor.
For example,
- client1: masked model $M_1 = R_1 + m_{1,2} - m_{3,1}$
- client2: masked model $M_2 = R_2 + m_{2,3} - m_{1,2}$
- client3: masked model $M_3 = R_3 + m_{3,1} - m_{2,3}$
Then, when the server sums up the masked models $M_n$:

$$M_1 + M_2 + M_3 = (R_1 + m_{1,2} - m_{3,1}) + (R_2 + m_{2,3} - m_{1,2}) + (R_3 + m_{3,1} - m_{2,3}) = R_1 + R_2 + R_3$$
You can see how the shared masks between the clients cancel out, so the server can compute the sum of the raw models $R_n$ without learning their individual values.
Regarding the privacy of the raw models $R_n$: each client can only compute the masks shared with its own neighbors. For example:
- client1 does not know $m_{2,3}$ => cannot reconstruct either $R_2$ or $R_3$
- client2 does not know $m_{3,1}$ => cannot reconstruct either $R_1$ or $R_3$
- client3 does not know $m_{1,2}$ => cannot reconstruct either $R_1$ or $R_2$
How can this be done more concretely? A pair of clients can generate a shared mask using the (elliptic-curve) Diffie-Hellman key exchange protocol as follows (a toy sketch of these steps follows below):
1. Each client generates their own public-private key pair.
2. Clients publish their public keys.
3. Each client locally computes the shared mask by multiplying their own private key with their neighbor's public key.
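Here is a toy Python sketch of those three steps. It uses classic finite-field Diffie-Hellman with illustrative parameters rather than the actual ECDH circuit from zk-kit.noir, and it masks a single parameter per client, but it shows how each pair of neighbors derives the same mask locally and how the masks cancel in the server's sum.

```python
# Toy Diffie-Hellman mask derivation and cancellation (illustrative parameters,
# NOT the ECDH circuit used in the demo). Real masks would be derived per
# model parameter and reduced to the fixed-point range.
import secrets as rnd

P = 2**127 - 1          # toy prime modulus
G = 5                   # generator

def keygen():
    sk = rnd.randbelow(P - 2) + 1
    return sk, pow(G, sk, P)          # (private key, public key)

def shared_mask(my_sk, their_pk):
    return pow(their_pk, my_sk, P)    # both neighbors compute the same value

# 1. each client generates a key pair, 2. publishes the public key
sk1, pk1 = keygen()
sk2, pk2 = keygen()
sk3, pk3 = keygen()

# 3. neighbors derive the pair-wise masks locally
m12 = shared_mask(sk1, pk2); assert m12 == shared_mask(sk2, pk1)
m23 = shared_mask(sk2, pk3); assert m23 == shared_mask(sk3, pk2)
m31 = shared_mask(sk3, pk1); assert m31 == shared_mask(sk1, pk3)

# masking one (already fixed-point encoded) parameter per client, mod P
R1, R2, R3 = 111, 222, 333
M1 = (R1 + m12 - m31) % P
M2 = (R2 + m23 - m12) % P
M3 = (R3 + m31 - m23) % P
assert (M1 + M2 + M3) % P == (R1 + R2 + R3) % P   # masks cancel in the sum
```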
For the shared key generation, I used the ECDH library inside the zk-kit.noir library set developed by PSE.
You can check the whole masking prover here.
3. Aggregation in server
The aggregation process is fairly simple: the server fetches the published masked models from the blockchain and sums them up element-wise.
```
// Sum up weights and bias at the same index for each class, submitted by clients 1-3
//                           |class1                      |  |class2                      |  |class3                      |
// Model 1 (from client 1):  [w111, w112, w113, w114, b11]  [w121, w122, w123, w124, b12]  [w131, w132, w133, w134, b13]
// Model 2 (from client 2):  [w211, w212, w213, w214, b21]  [w221, w222, w223, w224, b22]  [w231, w232, w233, w234, b23]
// Model 3 (from client 3):  [w311, w312, w313, w314, b31]  [w321, w322, w323, w324, b32]  [w331, w332, w333, w334, b33]
// Aggregated global model:  [w111+w211+w311, w112+w212+w312, ..., b11+b21+b31] ...
```
One thing to mention: since I wanted to perform a weighted average for the model aggregation, clients actually submit `weights * number of samples` and `bias * number of samples`, along with the `number of samples` they used for their training (they append the `number of samples` to the array of local models submitted to the blockchain). The server divides the summed weights and bias by the total number of samples at the end. You can check the aggregation prover here.
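As a plain-Python sketch (not the actual Noir aggregation prover, and ignoring the masks and the fixed-point encoding), the weighted averaging looks roughly like this:

```python
# Illustrative weighted-average aggregation: clients submit parameters already
# multiplied by their sample count; the server sums element-wise and divides
# by the total number of samples. Shapes and values are made up.
def aggregate(submissions):
    """submissions: list of (scaled_params, num_samples) per client, where
    scaled_params = [param * num_samples for each weight and bias]."""
    length = len(submissions[0][0])
    total_samples = sum(n for _, n in submissions)
    summed = [0.0] * length
    for params, _ in submissions:
        for i, p in enumerate(params):
            summed[i] += p
    # weighted average = sum(param_i * n_i) / sum(n_i)
    return [s / total_samples for s in summed]

# Example: 3 clients, each submitting [w1, w2, b] pre-multiplied by num_samples.
client_subs = [
    ([0.5 * 30, -0.2 * 30, 0.1 * 30], 30),
    ([0.6 * 50, -0.1 * 50, 0.2 * 50], 50),
    ([0.4 * 30, -0.3 * 30, 0.0 * 30], 30),
]
print(aggregate(client_subs))  # weighted average of the three local models
```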
Notes on fixed-point arithmetic
When writing the masking and aggregation provers, I had to be careful about the fixed-point arithmetic range checks. In machine learning, you almost always get decimal numbers in your operations. However, since you cannot directly express decimal numbers inside a ZK circuit, there is a technique of using fixed-point arithmetic: you scale small decimal numbers by some fixed factor and use the first half of the field elements (<= ~126 bits) to represent positive numbers and the second half (>= ~126 bits) to represent negative numbers.

Following this encoding rule, you need to add a bit-size check for the operands, such as `assert_bitsize::<n>`, before performing each arithmetic operation in order to not overflow the field. (With Hashcloak's noir-mpc-ml library, you can call these assertions in a more customizable way: the underlying `Quantized` struct does not automatically assert them before arithmetic operations, so you can reduce the number of constraints and make the ZK circuit more performant.)

For concreteness, I added detailed comments before each `assert_bitsize` in my code, so if you're curious about how to do safe addition and multiplication in fixed-point arithmetic, you can check them in my masking prover and aggregation prover.
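To illustrate the encoding and the role of the bit-size checks, here is a plain-Python sketch. It is not the Noir `Quantized` type; the scale factor and the 124-bit bound are illustrative. Decimals are scaled by a fixed factor, positive values live in the low half of the field, negative values are represented as p - |x|, and each operand is range-checked before an operation so the result cannot wrap around the field.

```python
# Illustrative fixed-point encoding over the BN254 scalar field.
P = 21888242871839275222246405745257275088548364400416034343698204186575808495617
SCALE = 2**16  # fixed scaling factor for decimals (example value)

def encode(x: float) -> int:
    # positive values land in the low half of the field,
    # negative values wrap around to the high half (p - |x|)
    return round(x * SCALE) % P

def decode(v: int) -> float:
    if v > P // 2:          # high half of the field represents negatives
        v -= P
    return v / SCALE

def assert_bitsize(v: int, bits: int) -> None:
    # check the magnitude of the signed value, mirroring the idea behind
    # calling assert_bitsize::<n> before each arithmetic operation
    mag = P - v if v > P // 2 else v
    assert mag < 2**bits, "operand too large: result could overflow the field"

def add(a: int, b: int) -> int:
    assert_bitsize(a, 124)
    assert_bitsize(b, 124)
    return (a + b) % P      # same SCALE factor, magnitude stays far below p/2

def mul(a: int, b: int) -> int:
    assert_bitsize(a, 124)
    assert_bitsize(b, 124)   # 124 + 124 = 248 bits < ~253 bits (p/2), no wrap-around
    prod = a * b % P         # now carries a factor of SCALE^2
    signed = prod - P if prod > P // 2 else prod
    return (signed // SCALE) % P   # drop one SCALE factor, re-encode

x, y = encode(1.5), encode(-2.25)
print(decode(add(x, y)))   # -0.75
print(decode(mul(x, y)))   # -3.375
```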
Future Research Direction
There are quite a few things I want to research more and add to the system.
1. Training dataset validation
Based on the assumption that we do not trust the clients (that's why we ask them to prove the correct execution of local training and masking!), we should probably check that they used valid input data as well. Many so-called "data poisoning" attacks are known in federated learning: clients can maliciously use invalid (or carefully crafted) training inputs to manipulate the global model's performance. To prevent such attacks, how can we perform some input validation while keeping the inputs private? I have just started researching this, but it generally seems that you can check whether a local model update follows some expected distribution, and thereby detect whether the client manipulated their training inputs.
This talk (which the above picture is from) presents interesting research on how to combat backdoor attacks, a type of data poisoning attack, without entirely eliminating the manipulated model updates from the malicious clients (so that the global model accuracy stays fairly high). If we were to employ such validation, it would have to be done on the client side inside ZK, before masking the model.
2. ZK circuit for more complex training algorithm… or fine-tuning?
Currently, clients locally execute a fairly simple training algorithm, logistic regression, inside ZK. But can we use more complex algorithms such as neural networks? I'm looking into the Zator project to see if there is anything reusable for training. Also, maybe instead of executing the whole training on the client side, we could broadcast a pre-trained model on chain and let clients perform just fine-tuning on top of it. How much could that reduce the computational cost on the client side? (Note that this task is more of an intellectual curiosity. For the verifiable federated learning use cases I have in mind, I believe using models with more interpretability, such as logistic regression, makes more sense.)
3. Storing local model updates offchain
Each client currently submits a local model, an array of 4 weights and 1 bias for each of 3 classes (e.g. `Model 1 (from client 1): [w111, w112, w113, w114, b11] [w121, w122, w123, w124, b12] [w131, w132, w133, w134, b13]` as shown in section 3. Aggregation in server), directly to the blockchain. This works because the models are extremely small, but what if the parameter size grows? Then I can easily switch to a design where clients only publish a hash of their local model on-chain, while uploading the full local model to a decentralized storage such as IPFS. When the server retrieves a local model from IPFS, it recomputes the hash of the fetched file and verifies that it matches the hash stored on-chain.
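A minimal Python sketch of that flow (the helper and variable names are hypothetical, and the IPFS upload/fetch is stubbed out):

```python
# Hash-commitment flow for off-chain model storage (illustrative only).
import hashlib, json

def model_digest(model) -> str:
    # canonical serialization so client and server hash identical bytes
    return hashlib.sha256(json.dumps(model, sort_keys=True).encode()).hexdigest()

# Client side
local_model = {"weights": [[1, 2, 3, 4]], "bias": [5], "num_samples": 30}
onchain_hash = model_digest(local_model)   # this hash is what goes on-chain
# ... upload `local_model` to IPFS and keep the returned CID ...

# Server side, after fetching the file from IPFS
fetched = local_model                      # stand-in for the IPFS fetch
assert model_digest(fetched) == onchain_hash, "model does not match on-chain commitment"
```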
4. Batched/Packed secret sharing for masking models
This is a complete change in the cryptographic technique to mask the models.
For better security
Imagine you're client $n$, which publishes a masked model $M_n$. If an attacker colludes with your two neighbors, client $n-1$ and client $n+1$ (receiving their secret keys $sk_{n-1}$ and $sk_{n+1}$, respectively), the attacker can reconstruct your raw model $R_n$ ($= M_n - m_{n,n+1} + m_{n-1,n}$). This is much worse than the t-out-of-n threshold security model of MPC.
For better efficiency
In addition, the current one-time pad technique requires generating a mask for each model parameter, meaning that the required number of masks grows linearly with the number of model parameters. Also, because we employ fixed-point arithmetic, each mask is constrained to be less than ~124 bits; but if we employed a different encoding technique for decimal numbers inside ZK, the mask bit-size would also have to grow linearly with each parameter's bit-size for good security.
I'm now looking into a technique called batched/packed secret sharing, which allows us to encode multiple secrets (e.g. the weights and biases in our case) into a single polynomial. An untrusted party (e.g. a server) can then receive the shares and compute the element-wise sum across the different secret vectors without learning their individual values.
I’m not entirely sure about the computational efficiency or interaction overhead when it comes to multiplication (which the server needs to perform for weighted averaging, but we could potentially offload that to the clients after they fetch the global model), but this MPC-based approach might make more sense if we prioritize security and asymptotic efficiency.
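As a rough illustration of the idea, here is a toy Python sketch of packed Shamir secret sharing over a prime field (the parameters are illustrative and this is not a production scheme): each client packs several secrets into one polynomial, the share holders sum the shares they receive element-wise, and reconstructing the summed polynomial reveals only the element-wise sums of the clients' secret vectors.

```python
# Toy packed (batched) Shamir secret sharing. Each client packs K secrets into
# ONE polynomial of degree K+T-1; up to T shares reveal nothing about them.
import random

P = 2**61 - 1                                # toy prime modulus
K = 3                                        # secrets packed per polynomial
T = 1                                        # privacy threshold
N = 5                                        # number of share holders
SECRET_XS = [P - 1 - i for i in range(K)]    # evaluation points hiding the secrets
SHARE_XS = list(range(1, N + 1))             # evaluation points handed to holders

def interpolate(points, x):
    """Evaluate at x the unique polynomial passing through `points` (Lagrange)."""
    total = 0
    for i, (xi, yi) in enumerate(points):
        num, den = 1, 1
        for j, (xj, _) in enumerate(points):
            if i != j:
                num = num * (x - xj) % P
                den = den * (xi - xj) % P
        total = (total + yi * num * pow(den, P - 2, P)) % P
    return total

def share(secrets):
    """Pin a degree-(K+T-1) polynomial by K secret points + T random points,
    then hand out its evaluations at SHARE_XS."""
    assert len(secrets) == K
    rand_xs = [P - 1 - K - i for i in range(T)]
    points = list(zip(SECRET_XS, secrets)) + [(x, random.randrange(P)) for x in rand_xs]
    return [interpolate(points, x) for x in SHARE_XS]

def reconstruct_sums(summed_shares):
    """From K+T summed shares, recover the element-wise sums of the packed secrets."""
    points = list(zip(SHARE_XS, summed_shares))[: K + T]
    return [interpolate(points, x) for x in SECRET_XS]

# Three clients, each packing a K-length (already fixed-point encoded) model.
models = [[11, 22, 33], [40, 50, 60], [7, 8, 9]]
summed = [0] * N
for m in models:
    for i, s in enumerate(share(m)):
        summed[i] = (summed[i] + s) % P      # each holder sums the shares it received

print(reconstruct_sums(summed))              # -> [58, 80, 102]
```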
Alternatively, we could use a fully connected graph topology so that $n-1$ clients need to collude to reveal a raw model. However, this approach comes with a cost of $O(nm)$ shared mask generations ($n$ = number of clients, $m$ = model size).
5. Client dropout tolerance + real-time join
Right now, this system does not tolerate any client dropouts after the clients sign up for the system. If the server cannot obtain all the local models to aggregate, the masks added/subtracted by each remaining client won't cancel out in the partial aggregation, so the result will just be a gibberish value. I'm looking into this paper, which seems to solve this problem by bringing (secret-sharing based) MPC in before the masking. In their protocol, each client has to communicate with $O(\log n)$ other clients (where $n$ is the total number of clients) in order for the server to obtain the correct aggregated model values, even after some clients drop out in the middle of the protocol. Also, is there any way to allow a new client to join the training without requiring the existing nodes to redo the masking?
6. Reward system
This is probably the coolest extension this system could have. Clients can get compensated for contributing their own data to train a model, while their data remains private. But what are the right metrics to calculate a client's contribution? There are several mathematical functions that measure how much a data point influenced the output model parameters during training, such as this classic influence function. Or maybe the key question for fair attribution is "who taught the AI about something earlier than anyone else?" rather than how much you influenced the AI's learning over time, in which case the reward should be given based on the time ranking of data provision. This topic involves more philosophical thinking, and the discussion should probably be tied to concrete use cases, so I will write more about the design of the reward system in a separate post.
Special thanks to Sora, Ying Tong, Pierre, and Timofey for giving me suggestions/ideas for the future research directions, and to other friends (especially those I spent time with at the Community Privacy Residency this spring!) for discussions.
If you have any feedback or comments on this post and are willing to engage in a meaningful discussion, please leave them in the HackMD draft: https://hackmd.io/@yuriko/BJ1ptkh0yx