奧胡斯大學(xué)密碼學(xué)PhD、Datadog機(jī)器學(xué)習(xí)工程師Morten Dahl介紹了如何基于安全多方計(jì)算協(xié)議實(shí)現(xiàn)私密深度學(xué)習(xí)模型。
受最近一篇混合深度學(xué)習(xí)和同態(tài)加密的博客的啟發(fā)(見(jiàn)基于Numpy實(shí)現(xiàn)同態(tài)加密神經(jīng)網(wǎng)絡(luò)),我覺(jué)得使用安全多方計(jì)算(secure multi-party computation)替換同態(tài)加密實(shí)現(xiàn)深度學(xué)習(xí)會(huì)很有趣。
在本文中,我們將從頭構(gòu)建一個(gè)簡(jiǎn)單的安全多方計(jì)算協(xié)議,然后嘗試基于它訓(xùn)練簡(jiǎn)單的神經(jīng)網(wǎng)絡(luò)進(jìn)行基本布爾值計(jì)算。本文的相關(guān)代碼可以通過(guò)GitHub取得(mortendahl/privateml/simple-boolean-functions)。 假設(shè)有未串通的三方P0、P1、P2,愿意一起進(jìn)行計(jì)算,即訓(xùn)練神經(jīng)網(wǎng)絡(luò)并使用它們進(jìn)行預(yù)測(cè);然而,出于一些理由,他們不愿意泄露學(xué)習(xí)好的模型。同時(shí)假定有些用戶(hù)愿意在保持私密的前提下提供訓(xùn)練數(shù)據(jù),有些用戶(hù)也有興趣在他們的輸入保持私密的前提下使用學(xué)習(xí)好的模型。
為了能夠做到這一點(diǎn),我們需要安全地在特定精度下計(jì)算有理數(shù);具體而言,對(duì)它們做加法和乘法。我們同時(shí)需要計(jì)算sigmoid函數(shù)1/(1+np.exp(-x)),這一函數(shù)的傳統(tǒng)形式在安全設(shè)定下會(huì)導(dǎo)致驚人沉重的運(yùn)算。因此,我們將依照基于Numpy實(shí)現(xiàn)同態(tài)加密神經(jīng)網(wǎng)絡(luò)中的做法,使用多項(xiàng)式逼近sigmoid函數(shù),不過(guò)我們會(huì)進(jìn)行一點(diǎn)優(yōu)化。
安全多方計(jì)算
同態(tài)加密(Homomorphic Encryption,HE)和安全多方計(jì)算(secure Multi-Party Computation,MPC)是現(xiàn)代密碼學(xué)中密切相關(guān)的兩個(gè)領(lǐng)域,常?;ハ嗍褂脤?duì)方的技術(shù)以便解決大致相同的問(wèn)題:計(jì)算接受私密數(shù)據(jù)輸入的函數(shù)而不泄露任何東西,除了(可選)最終輸出。例如,在我們的私密機(jī)器學(xué)習(xí)設(shè)定下,兩種技術(shù)可以用來(lái)訓(xùn)練我們的模型并進(jìn)行預(yù)測(cè)(不過(guò)在HE的情形下,如果數(shù)據(jù)來(lái)自使用不同加密鑰的用戶(hù),需要做一些專(zhuān)門(mén)的技術(shù)處理)。
就其本身而言,從高層看,HE經(jīng)常可以用MPC替換,反之亦然。至少就今天而言,兩者的區(qū)別大致是HE不怎么需要交互,但是需要昂貴的計(jì)算,而MPC的計(jì)算很廉價(jià),但需要大量交互。換句話(huà)說(shuō),MCP用兩方或多方間的交互取代了昂貴的計(jì)算。
目前而言,這在實(shí)踐中提供了更好的性能,以至于人們可以主張MCP是明顯更成熟的技術(shù)——作為這一主張的依據(jù),已經(jīng)存在好幾家公司提供基于MPC的服務(wù)。
定點(diǎn)數(shù)算術(shù)
運(yùn)算將在一個(gè)有限域上進(jìn)行,因此我們首先需要決定如何將有理數(shù)r表示為域元素,即取自0, 1, ..., Q-1的整數(shù)x(Q為質(zhì)數(shù))。我們將采用典型的做法,根據(jù)固定的精度放大每個(gè)有理數(shù),比如,在6位精度的情形下,我們將放大10**6倍,然后將結(jié)果的整數(shù)部分作為定點(diǎn)數(shù)表示。例如,Q = 10000019時(shí),我們得到encode(0.5) == 500000和encode(-0.5) == 10000019 - 500000 == 9500019。
def encode(rational):
upscaled = int(rational * 10**6)
field_element = upscaled % Q
return field_element
def decode(field_element):
upscaled = field_element if field_element <= Q/2else field_element - Q
rational = upscaled / 10**6
return rational
注意,在這一表示下,加法是直截了當(dāng)?shù)模?r * 10**6) + (s * 10**6) == (r + s) * 10**6,而乘法添加了額外的放大因子,我們需要處理掉以保持精度和避免爆掉數(shù)字:(r * 10**6) * (s * 10**6) == (r * s) * 10**6 * 10**6。
共享和重建數(shù)據(jù)
編碼輸入后,每個(gè)用戶(hù)接著需要一種和他方共享數(shù)據(jù)的方式,以便用于計(jì)算,不過(guò),數(shù)據(jù)需要保持私密。
為了達(dá)到這一點(diǎn),我們需要的配料是秘密共享(secret sharing)。秘密共享將一個(gè)值以某種方式分成三份,任何見(jiàn)到少于三份數(shù)據(jù)的人,無(wú)法得知關(guān)于值的任何信息;然而,一旦見(jiàn)到所有三份,可以輕易地重建值。
出于簡(jiǎn)單性考慮,這里我們將使用復(fù)制秘密共享(replicated secret sharing),其中每方收到不止一份數(shù)據(jù)。具體而言,私密值x分成部分x0、x1、x2,滿(mǎn)足x == x0 + x1 + x2。P0方收到(x0,x1),P1收到(x1,x2),P2收到(x2,x0)。不過(guò)本教程中這一點(diǎn)將是隱式的,本文會(huì)直接將共享的x儲(chǔ)存為由三部分組成的向量[x0, x1, x2]。
def share(x):
x0 = random.randrange(Q)
x1 = random.randrange(Q)
x2 = (x - x0 - x1) % Q
return [x0, x1, x2]
當(dāng)兩方以上同意將一個(gè)值表露給某人時(shí),他們直接發(fā)送他們所有的部分,從而使重建得以進(jìn)行。
def reconstruct(shares):
return sum(shares) % Q
然而,如果部分是以下小節(jié)提到的一次或多次安全運(yùn)算的結(jié)果,出于私密性考慮,我們?cè)谥亟ㄇ斑M(jìn)行一次再共享。
def reshare(xs):
Y = [ share(xs[0]), share(xs[1]), share(xs[2]) ]
return [ sum(row) % Q for row in zip(*Y) ]
嚴(yán)格來(lái)說(shuō)這是不必要的,但是進(jìn)行這一步可以更容易地說(shuō)明為什么協(xié)議是安全的;直觀地,它確保分享的部分是新鮮的,不包含關(guān)于我們用于計(jì)算結(jié)果的數(shù)據(jù)的信息。
加法和減法
這樣我們已經(jīng)可以進(jìn)行安全的加法和減法運(yùn)算了:每方直接加減其擁有的部分,由于(x0 + x1 + x2) + (y0 + y1 + y2) == (x0 + y0) + (x1 + y1) + (x2 + y2),通過(guò)這一操作可以得到x + y的三部分(技術(shù)上說(shuō)應(yīng)該是reconstruct(x) + reconstruct(y),但是隱式寫(xiě)法更易讀)。
def add(x, y):
return [ (xi + yi) % Q for xi, yi in zip(x, y) ]
def sub(x, y):
return [ (xi - yi) % Q for xi, yi in zip(x, y) ]
注意這不需要進(jìn)行任何通訊,因?yàn)檫@些都是本地運(yùn)算。
乘法
由于每方擁有兩個(gè)部分,乘法可以通過(guò)類(lèi)似上面提到的加法和減法的方式進(jìn)行,即,每方基于已擁有的部分計(jì)算一個(gè)新部分。具體而言,對(duì)下面的代碼中定義的z0、z1、z2而言,我們有x * y == z0 + z1 + z2(技術(shù)上說(shuō)……)
然而,每方擁有兩個(gè)部分的不變性沒(méi)有滿(mǎn)足,而像P1直接將z1發(fā)給P0這樣的做法是不安全的。一個(gè)簡(jiǎn)單的修正是直接將每份zi當(dāng)成私密輸入共享;這樣就得到了一個(gè)正確而安全的共享w(乘積)。
def mul(x, y):
# 本地運(yùn)算
z0 = (x[0]*y[0] + x[0]*y[1] + x[1]*y[0]) % Q
z1 = (x[1]*y[1] + x[1]*y[2] + x[2]*y[1]) % Q
z2 = (x[2]*y[2] + x[2]*y[0] + x[0]*y[2]) % Q
# 重共享和分發(fā);這里需要通訊
Z = [ share(z0), share(z1), share(z2) ]
w = [ sum(row) % Q for row in zip(*Z) ]
# 將雙精度轉(zhuǎn)回單精度
v = truncate(w)
return v
不過(guò)還有一個(gè)問(wèn)題,如前所述,reconstruct(w)具有雙精度:它編碼時(shí)使用的放大因子是10**6 * 10**6,而不是10**6。在不安全設(shè)定下,我們本可以通過(guò)標(biāo)準(zhǔn)的除法(除以10**6)來(lái)修正這一點(diǎn),然而,由于我們操作的是有限域中的秘密共享元素,這變得不那么直截了當(dāng)了。
除以一個(gè)公開(kāi)的常量,這里是10**6,足夠簡(jiǎn)單:我們直接將部分乘以其域中的逆元素10**(-6)。對(duì)某v和u < 10**6,如果reconstruct(w) == v * 10**6 + u,那么乘以逆元素得到v + u * 10**(-6),那么v就是我們要找到的值。在不安全設(shè)定下,殘值u * 10**(-6)足夠小,可以通過(guò)取整消除。與此不同,在安全設(shè)定下,基于有限域元素,這一語(yǔ)義丟失了,我們需要通過(guò)其他方法擺脫殘值。
一種方法是確保u == 0。具體而言,如果我們事先知道u,那么我們可以不對(duì)w作除法,而對(duì)w' == (w - share(u))作除法,接著我們就如愿以?xún)敚玫絭' == v和u' == 0,即,沒(méi)有任何殘值。
剩下的問(wèn)題當(dāng)然是如何安全地得到u,以便計(jì)算w'。具體細(xì)節(jié)見(jiàn)CS’10,不過(guò)基本的思路是首先在w上加上一個(gè)大的掩碼,將掩碼后的值表露給其中一方,使其得以計(jì)算掩碼后的u。最后,共享和解掩碼這一掩碼后的值,然后計(jì)算w'。
def truncate(a):
# 映射到正值范圍
b = add(a, share(10**(6+6-1)))
# 應(yīng)用僅有P0知道的掩碼,然后重建掩碼后的b,發(fā)送給P1或P2
mask = random.randrange(Q) % 10**(6+6+KAPPA)
mask_low = mask % 10**6
b_masked = reconstruct(add(b, share(mask)))
# 提取低位數(shù)字
b_masked_low = b_masked % 10**6
b_low = sub(share(b_masked_low), share(mask_low))
# 去除低位數(shù)字
c = sub(a, b_low)
# 除法
d = imul(c, INVERSE)
return d
注意上面的imul是本地操作,將每個(gè)共享部分乘以公開(kāi)的常數(shù),這里是10**6的域中逆元素。
安全數(shù)據(jù)類(lèi)型
最后,我們將以上過(guò)程包裹進(jìn)一個(gè)定制的抽象數(shù)據(jù)類(lèi)型,這樣我們之后表達(dá)神經(jīng)網(wǎng)絡(luò)的時(shí)候就可以使用NumPy了。
classSecureRational(object):
def __init__(self, secret=None):
self.shares = share(encode(secret)) if secret isnotNoneelse []
return z
def reveal(self):
return decode(reconstruct(reshare(self.shares)))
def __repr__(self):
return"SecureRational(%f)" % self.reveal()
def __add__(x, y):
z = SecureRational()
z.shares = add(x.shares, y.shares)
return z
def __sub__(x, y):
z = SecureRational()
z.shares = sub(x.shares, y.shares)
return z
def __mul__(x, y):
z = SecureRational()
z.shares = mul(x.shares, y.shares)
return z
def __pow__(x, e):
z = SecureRational(1)
for _ in range(e):
z = z * x
return z
基于這一類(lèi)型,我們可以安全地對(duì)這樣的值進(jìn)行操作:
x = SecureRational(.5)
y = SecureRational(-.25)
z = x * y
assert(z.reveal() == (.5) * (-.25))
此外,需要調(diào)試的時(shí)候,我們可以切換為不安全類(lèi)型而不需要修改其余(神經(jīng)網(wǎng)絡(luò))代碼。再比如,我們可以隔離計(jì)數(shù)器的使用,查看進(jìn)行了多少次乘法,進(jìn)而讓我們模擬下需要多少通訊。
深度學(xué)習(xí)
這里用“深度學(xué)習(xí)”這個(gè)術(shù)語(yǔ)屬于夸夸其談,因?yàn)槲覀冎皇呛?jiǎn)單地?cái)[弄了下基于Numpy實(shí)現(xiàn)同態(tài)加密神經(jīng)網(wǎng)絡(luò)中的神經(jīng)網(wǎng)絡(luò)學(xué)習(xí)基本布爾值函數(shù)。
一個(gè)簡(jiǎn)單函數(shù)
第一個(gè)實(shí)驗(yàn)是訓(xùn)練網(wǎng)絡(luò)以識(shí)別序列中的第一位。下面的代碼中,X中的四行是輸入的訓(xùn)練數(shù)據(jù),y中相應(yīng)的列是所需輸出。
X = np.array([
[0,0,1],
[0,1,1],
[1,0,1],
[1,1,1]
])
y = np.array([[
0,
0,
1,
1
]]).T
我們將使用同樣的雙層網(wǎng)絡(luò),不過(guò)我們會(huì)將下面定義的sigmoid逼近函數(shù)作為參數(shù)。secure函數(shù)是一個(gè)簡(jiǎn)單的輔助函數(shù),將所有值轉(zhuǎn)換為我們的安全數(shù)據(jù)類(lèi)型。
classTwoLayerNetwork:
def __init__(self, sigmoid):
self.sigmoid = sigmoid
def train(self, X, y, iterations=1000):
# 初始化權(quán)重
self.synapse0 = secure(2 * np.random.random((3,1)) - 1)
# 訓(xùn)練
for i in range(iterations):
# 前向傳播
layer0 = X
layer1 = self.sigmoid.evaluate(np.dot(layer0, self.synapse0))
# 反向傳播
layer1_error = y - layer1
layer1_delta = layer1_error * self.sigmoid.derive(layer1)
# 更新
self.synapse0 += np.dot(layer0.T, layer1_delta)
def predict(self, X):
layer0 = X
layer1 = self.sigmoid.evaluate(np.dot(layer0, self.synapse0))
return layer1
同時(shí),我們將使用原文提出的sigmoid逼近,即標(biāo)準(zhǔn)麥克勞林/泰勒多項(xiàng)式的前五項(xiàng)。出于可讀性考慮,我這里用了一個(gè)簡(jiǎn)單多項(xiàng)式演算,有待進(jìn)一步優(yōu)化,比如使用秦九韶算法減少乘法的數(shù)目。
classSigmoidMaclaurin5:
def __init__(self):
ONE = SecureRational(1)
W0 = SecureRational(1/2)
W1 = SecureRational(1/4)
W3 = SecureRational(-1/48)
W5 = SecureRational(1/480)
self.sigmoid = np.vectorize(lambda x: W0 + (x * W1) + (x**3 * W3) + (x**5 * W5))
self.sigmoid_deriv = np.vectorize(lambda x: (ONE - x) * x)
def evaluate(self, x):
return self.sigmoid(x)
def derive(self, x):
return self.sigmoid_deriv(x)
實(shí)現(xiàn)了這個(gè)之后我們就可以訓(xùn)練和演算網(wǎng)絡(luò)了(細(xì)節(jié)見(jiàn)notebook),這里使用了10000次迭代。
# 設(shè)置隨機(jī)數(shù)種子以獲得可復(fù)現(xiàn)的結(jié)果
random.seed(1)
np.random.seed(1)
# 選擇逼近
sigmoid = SigmoidMaclaurin5()
# 訓(xùn)練
network = TwoLayerNetwork(sigmoid)
network.train(secure(X), secure(y), 10000)
# 演算預(yù)測(cè)
evaluate(network)
注意訓(xùn)練數(shù)據(jù)在輸入網(wǎng)絡(luò)之前是安全共享的,并且學(xué)習(xí)到的權(quán)重從未泄露。預(yù)測(cè)同理,只有網(wǎng)絡(luò)的用戶(hù)知道輸入和輸出。
Error: 0.00539115
Error: 0.0025606125
Error: 0.00167358
Error: 0.001241815
Error: 0.00098674
Error: 0.000818415
Error: 0.0006990725
Error: 0.0006100825
Error: 0.00054113
Error: 0.0004861775
Layer0 weights:
[[SecureRational(4.974135)]
[SecureRational(-0.000854)]
[SecureRational(-2.486387)]]
Prediction on [000]: 0 (0.50000000)
Prediction on [001]: 0 (0.00066431)
Prediction on [010]: 0 (0.49978657)
Prediction on [011]: 0 (0.00044076)
Prediction on [100]: 1 (5.52331855)
Prediction on [101]: 1 (0.99969213)
Prediction on [110]: 1 (5.51898314)
Prediction on [111]: 1 (0.99946841)
從上面的演算來(lái)看,神經(jīng)網(wǎng)絡(luò)確實(shí)看起來(lái)學(xué)習(xí)到了所要求的函數(shù),在未見(jiàn)輸入上也能給出正確的預(yù)測(cè)。
稍微高級(jí)些的函數(shù)
在下一個(gè)實(shí)驗(yàn)中,神經(jīng)網(wǎng)絡(luò)無(wú)法像之前一樣鏡像三個(gè)組件的其中一個(gè),從直觀上說(shuō),需要計(jì)算第一位和第二位的異或(第三位是偏離)。
X = np.array([
[0,0,1],
[0,1,1],
[1,0,1],
[1,1,1]
])
y = np.array([[
0,
1,
1,
0
]]).T
如Numpy實(shí)現(xiàn)神經(jīng)神經(jīng)網(wǎng)絡(luò):反向傳播一文所解釋的,使用雙層神經(jīng)網(wǎng)絡(luò)只能給出無(wú)意義的結(jié)果,本質(zhì)上是在說(shuō)“讓我們?nèi)右幻队矌虐伞薄?/p>
Error: 0.500000005
Error: 0.5
Error: 0.5000000025
Error: 0.5000000025
Error: 0.5
Error: 0.5
Error: 0.5
Error: 0.5
Error: 0.5
Error: 0.5
Layer0 weights:
[[SecureRational(0.000000)]
[SecureRational(0.000000)]
[SecureRational(0.000000)]]
Prediction on [000]: 0 (0.50000000)
Prediction on [001]: 0 (0.50000000)
Prediction on [010]: 0 (0.50000000)
Prediction on [011]: 0 (0.50000000)
Prediction on [100]: 0 (0.50000000)
Prediction on [101]: 0 (0.50000000)
Prediction on [110]: 0 (0.50000000)
Prediction on [111]: 0 (0.50000000)
提議的補(bǔ)救措施是在網(wǎng)絡(luò)中引入另一層:
classThreeLayerNetwork:
def __init__(self, sigmoid):
self.sigmoid = sigmoid
def train(self, X, y, iterations=1000):
# 初始權(quán)重
self.synapse0 = secure(2 * np.random.random((3,4)) - 1)
self.synapse1 = secure(2 * np.random.random((4,1)) - 1)
# 訓(xùn)練
for i in range(iterations):
# 前向傳播
layer0 = X
layer1 = self.sigmoid.evaluate(np.dot(layer0, self.synapse0))
layer2 = self.sigmoid.evaluate(np.dot(layer1, self.synapse1))
# 反向傳播
layer2_error = y - layer2
layer2_delta = layer2_error * self.sigmoid.derive(layer2)
layer1_error = np.dot(layer2_delta, self.synapse1.T)
layer1_delta = layer1_error * self.sigmoid.derive(layer1)
# 更新
self.synapse1 += np.dot(layer1.T, layer2_delta)
self.synapse0 += np.dot(layer0.T, layer1_delta)
def predict(self, X):
layer0 = X
layer1 = self.sigmoid.evaluate(np.dot(layer0, self.synapse0))
layer2 = self.sigmoid.evaluate(np.dot(layer1, self.synapse1))
return layer2
然而,如果我們采用之前的方式訓(xùn)練網(wǎng)絡(luò),即使僅僅迭代100次,我們都將面臨一個(gè)奇怪的現(xiàn)象:突然之間,誤差、權(quán)重、預(yù)測(cè)分?jǐn)?shù)爆炸了,給出混亂的結(jié)果。
Error: 0.496326875
Error: 0.4963253375
Error: 0.50109445
Error: 4.50917445533e+22
Error: 4.20017387687e+22
Error: 4.38235385094e+22
Error: 4.65389939428e+22
Error: 4.25720845129e+22
Error: 4.50520005372e+22
Error: 4.31568874384e+22
Layer0 weights:
[[SecureRational(970463188850515564822528.000000)
SecureRational(1032362386093871682551808.000000)
SecureRational(1009706886834648285970432.000000)
SecureRational(852352894255113084862464.000000)]
[SecureRational(999182403614802557534208.000000)
SecureRational(747418473813466924711936.000000)
SecureRational(984098986255565992230912.000000)
SecureRational(865284701475152213311488.000000)]
[SecureRational(848400149667429499273216.000000)
SecureRational(871252067688430631387136.000000)
SecureRational(788722871059090631557120.000000)
SecureRational(868480811373827731750912.000000)]]
Layer1 weights:
[[SecureRational(818092877308528183738368.000000)]
[SecureRational(940782003999550335877120.000000)]
[SecureRational(909882533376693496709120.000000)]
[SecureRational(955267264038446787723264.000000)]]
Prediction on [000]: 1 (41452089757570437218304.00000000)
Prediction on [001]: 1 (46442301971509056372736.00000000)
Prediction on [010]: 1 (37164015478651618328576.00000000)
Prediction on [011]: 1 (43504970843252146044928.00000000)
Prediction on [100]: 1 (35282926617309558603776.00000000)
Prediction on [101]: 1 (47658769913438164484096.00000000)
Prediction on [110]: 1 (35957624290517111013376.00000000)
Prediction on [111]: 1 (47193714919561920249856.00000000)
導(dǎo)致這一切的原因很簡(jiǎn)單,但也許乍看起來(lái)不是那么明顯(至少對(duì)我而言)。盡管(前五項(xiàng))麥克勞林/泰勒逼近sigmoid函數(shù)在前面的網(wǎng)絡(luò)中表現(xiàn)良好,當(dāng)我們進(jìn)一步推進(jìn)時(shí),它完全崩塌了,產(chǎn)生的結(jié)果不僅不精確,而且數(shù)量級(jí)也不對(duì)。因此很快摧毀了我們可能使用的任何有窮數(shù)字表示,即使在非安全設(shè)定下也是如此,數(shù)字開(kāi)始溢出了。
技術(shù)上說(shuō)sigmoid函數(shù)演算的點(diǎn)積變得太大了,就我所知,這意味著神經(jīng)網(wǎng)絡(luò)變得非常自信。就此而言,問(wèn)題在于我們的逼近不允許神經(jīng)網(wǎng)絡(luò)變得足夠自信,否則精確度會(huì)非常糟糕。
我不清楚基于Numpy實(shí)現(xiàn)同態(tài)加密神經(jīng)網(wǎng)絡(luò)是如何避免這一問(wèn)題的,我最好的猜測(cè)是較低的初始權(quán)重和alpha更新參數(shù)使它可能在迭代次數(shù)較低的情形下繞過(guò)這個(gè)坑(看起來(lái)是少于300次迭代)。無(wú)比歡迎任何關(guān)于這方面的評(píng)論。
逼近sigmoid
既然是我們的sigmoid逼近阻礙了我們學(xué)習(xí)更高級(jí)的函數(shù),那么很自然地,我們接下來(lái)嘗試使用麥克勞林/泰勒多項(xiàng)式的更多項(xiàng)。
如下所示,加到第9項(xiàng)(而不是第5項(xiàng))確實(shí)能稍微增加一點(diǎn)進(jìn)展,但這點(diǎn)進(jìn)展遠(yuǎn)遠(yuǎn)不夠。此外,它塌得更快了。
Error: 0.49546145
Error: 0.4943132225
Error: 0.49390536
Error: 0.50914575
Error: 7.29251498137e+22
Error: 7.97702462371e+22
Error: 7.01752029207e+22
Error: 7.41001528681e+22
Error: 7.33032620012e+22
Error: 7.3022511184e+22
...
或者我們?cè)撧D(zhuǎn)而使用更少的項(xiàng)以更好地牽制崩塌?比如,只加到第3項(xiàng)?這確實(shí)有點(diǎn)作用,能讓我們?cè)诒浪坝?xùn)練500次迭代而不是100次。
Error: 0.4821573275
Error: 0.46344183
Error: 0.4428059575
Error: 0.4168092675
Error: 0.388153325
Error: 0.3619875475
Error: 0.3025045425
Error: 0.2366579675
Error: 0.19651228
Error: 0.1748352775
Layer0 weights:
[[SecureRational(1.455894) SecureRational(1.376838)
SecureRational(-1.445690) SecureRational(-2.383619)]
[SecureRational(-0.794408) SecureRational(-2.069235)
SecureRational(-1.870023) SecureRational(-1.734243)]
[SecureRational(0.712099) SecureRational(-0.688947)
SecureRational(0.740605) SecureRational(2.890812)]]
Layer1 weights:
[[SecureRational(-2.893681)]
[SecureRational(6.238205)]
[SecureRational(-7.945379)]
[SecureRational(4.674321)]]
Prediction on [000]: 1 (0.50918230)
Prediction on [001]: 0 (0.16883382)
Prediction on [010]: 0 (0.40589161)
Prediction on [011]: 1 (0.82447640)
Prediction on [100]: 1 (0.83164009)
Prediction on [101]: 1 (0.83317334)
Prediction on [110]: 1 (0.74354671)
Prediction on [111]: 0 (0.18736629)
然而,誤差和預(yù)測(cè)很糟糕,也沒(méi)有多少空間供增加迭代次數(shù)了(大約在550次迭代處崩塌)。
插值
作為替代,我們可以放棄標(biāo)準(zhǔn)多項(xiàng)式逼近,轉(zhuǎn)而嘗試在區(qū)間上進(jìn)行多項(xiàng)式插值。這里主要的參數(shù)是多項(xiàng)式的項(xiàng)數(shù),我們希望它保持在一個(gè)較低的值,以提高效率。不過(guò),系數(shù)的精度也是相關(guān)參數(shù)。
# 我們想要逼近的函數(shù)
f_real = lambda x: 1/(1+np.exp(-x))
# 我們想要優(yōu)化的區(qū)間
interval = np.linspace(-10, 10, 100)
# 給定項(xiàng)數(shù),進(jìn)行多項(xiàng)式插值
degree = 10
coefs = np.polyfit(interval, f_real(interval), degree)
# 降低插值系數(shù)的精度
precision = 10
coefs = [ int(x * 10**precision) / 10**precision for x in coefs ]
# 逼近函數(shù)
f_interpolated = np.poly1d(coefs)
一同繪制標(biāo)準(zhǔn)逼近和插值多項(xiàng)式(紅色曲線(xiàn))的圖像我們看到了改進(jìn)的希望:我們無(wú)法避免在某點(diǎn)崩塌,但它的崩塌點(diǎn)顯然要大很多。
當(dāng)然,我們也可以嘗試其他項(xiàng)數(shù)、精度、區(qū)間的組合,如下所示,不過(guò)對(duì)我們的應(yīng)用而言,上面的參數(shù)組合看起來(lái)已經(jīng)足夠了。
現(xiàn)在讓我們回到我們的三層網(wǎng)絡(luò),我們定義一個(gè)新的Sigmoid逼近:
classSigmoidInterpolated10:
def __init__(self):
ONE = SecureRational(1)
W0 = SecureRational(0.5)
W1 = SecureRational(0.2159198015)
W3 = SecureRational(-0.0082176259)
W5 = SecureRational(0.0001825597)
W7 = SecureRational(-0.0000018848)
W9 = SecureRational(0.0000000072)
self.sigmoid = np.vectorize(lambda x: \
W0 + (x * W1) + (x**3 * W3) + (x**5 * W5) + (x**7 * W7) + (x**9 * W9))
self.sigmoid_deriv = np.vectorize(lambda x:(ONE - x) * x)
def evaluate(self, x):
return self.sigmoid(x)
def derive(self, x):
return self.sigmoid_deriv(x)
……然后開(kāi)始訓(xùn)練:
# 設(shè)置隨機(jī)數(shù)種子以獲得可復(fù)現(xiàn)的結(jié)果
random.seed(1)
np.random.seed(1)
# 選擇逼近
sigmoid = SigmoidInterpolated10()
# 訓(xùn)練
network = TwoLayerNetwork(sigmoid)
network.train(secure(X), secure(y), 10000)
# 演算預(yù)測(cè)
evaluate(network)
現(xiàn)在,盡管我們運(yùn)行了10000次迭代,沒(méi)有發(fā)生崩塌,預(yù)測(cè)表現(xiàn)也提升了,只有一個(gè)預(yù)測(cè)錯(cuò)誤([0 1 0])。
Error: 0.0384136825
Error: 0.01946007
Error: 0.0141456075
Error: 0.0115575225
Error: 0.010008035
Error: 0.0089747225
Error: 0.0082400825
Error: 0.00769687
Error: 0.007286195
Error: 0.00697363
Layer0 weights:
[[SecureRational(3.208028) SecureRational(3.359444)
SecureRational(-3.632461) SecureRational(-4.094379)]
[SecureRational(-1.552827) SecureRational(-4.403901)
SecureRational(-3.997194) SecureRational(-3.271171)]
[SecureRational(0.695226) SecureRational(-1.560569)
SecureRational(1.758733) SecureRational(5.425429)]]
Layer1 weights:
[[SecureRational(-4.674311)]
[SecureRational(5.910466)]
[SecureRational(-9.854162)]
[SecureRational(6.508941)]]
Prediction on [000]: 0 (0.28170669)
Prediction on [001]: 0 (0.00638341)
Prediction on [010]: 0 (0.33542098)
Prediction on [011]: 1 (0.99287968)
Prediction on [100]: 1 (0.74297185)
Prediction on [101]: 1 (0.99361066)
Prediction on [110]: 0 (0.03599433)
Prediction on [111]: 0 (0.00800036)
注意錯(cuò)誤情形的分?jǐn)?shù)不是完全偏離,某種程度上和正確預(yù)測(cè)的零值有所不同。再運(yùn)行5000次迭代看起來(lái)不會(huì)改善這一點(diǎn),這時(shí)我們已經(jīng)快要崩塌了。
結(jié)語(yǔ)
本文重點(diǎn)介紹了一個(gè)簡(jiǎn)單的安全多方計(jì)算協(xié)議,而沒(méi)有顯式地論證開(kāi)頭提到的安全多方計(jì)算比同態(tài)加密更高效。我們看到,使用非?;镜牟僮鲗?shí)現(xiàn)私密機(jī)器學(xué)習(xí)確實(shí)是可能的。
也許更需要批評(píng)的是我們沒(méi)有測(cè)量運(yùn)行協(xié)議需要的通訊量,主要是每次乘法時(shí)所需交換的一些消息。基于這一簡(jiǎn)單的協(xié)議進(jìn)行大量計(jì)算的話(huà),顯然讓三方通過(guò)高速局域網(wǎng)連接會(huì)比較好。不過(guò)更高級(jí)的協(xié)議不僅減少了來(lái)回收發(fā)的數(shù)據(jù)量,同時(shí)改善了其他性質(zhì),比如回合數(shù)(像亂碼電路(garbled circuits)就可以將回合數(shù)壓到一個(gè)小常數(shù))。
最后,本文基本上將協(xié)議和機(jī)器學(xué)習(xí)過(guò)程看成是正交的,讓后者僅僅以一種黑盒的方式使用前者(除了sigmoid計(jì)算)。兩者更好的配合需要兩個(gè)領(lǐng)域的專(zhuān)門(mén)知識(shí),但可能顯著地提升總體性能。
-
神經(jīng)網(wǎng)絡(luò)
+關(guān)注
關(guān)注
42文章
4777瀏覽量
100960 -
深度學(xué)習(xí)
+關(guān)注
關(guān)注
73文章
5510瀏覽量
121345
原文標(biāo)題:基于安全多方計(jì)算協(xié)議實(shí)現(xiàn)私密深度學(xué)習(xí)模型
文章出處:【微信號(hào):jqr_AI,微信公眾號(hào):論智】歡迎添加關(guān)注!文章轉(zhuǎn)載請(qǐng)注明出處。
發(fā)布評(píng)論請(qǐng)先 登錄
相關(guān)推薦
評(píng)論