# Cmput 455 sample code
# How often do bandits based on Bernoulli experiments make the wrong choice?
# Written by Martin Mueller
from scipy.stats import binom
# If a single experiment returns 1 with probability p,
# what is the probability of getting exactly k wins in n experiments?
# Also see:
# https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.binom.html
def binomial(k, n, p):
    """Return P(X == k) for X ~ Binomial(n, p).

    Thin wrapper around scipy's binomial probability mass function:
    the chance of exactly k successes in n independent trials, each
    succeeding with probability p.
    """
    probability = binom.pmf(k, n, p)
    return probability
# Probability that the winrate of arm 1 over n1 tries is worse than k2/n2:
# sum over all outcomes k1 for which the winrate k1/n1 < k2/n2.
# For outcomes where the winrates are equal, we count the
# chance of making the right decision as 50-50.
# To avoid floating-point errors, the winrates are compared by
# cross-multiplying by n1*n2 (k1*n2 vs. k2*n1), which keeps the
# comparison in integers.
def worse(p1, n1, k2, n2):
    """Return the probability that arm 1 ends with a lower winrate than k2/n2.

    Arm 1 wins each of its n1 tries with probability p1. Every outcome k1
    (number of wins of arm 1) whose winrate k1/n1 is strictly below k2/n2
    contributes its full binomial probability; an exactly equal winrate
    contributes half of its probability (a 50-50 tie-break).

    Winrates are compared by cross-multiplying (k1*n2 vs. k2*n1) so the
    comparison stays in exact integer arithmetic.

    The loop is bounded by n1, so n2 == 0 no longer loops forever: with
    k2 == 0 every outcome then compares as a tie and counts half.
    """
    total = 0
    threshold = k2 * n1
    for k1 in range(n1 + 1):
        scaled = k1 * n2
        if scaled > threshold:
            # k1*n2 only grows with k1, so every later outcome is better too.
            break
        prob = binomial(k1, n1, p1)
        # Strictly worse counts fully; an equal winrate counts as half.
        total += prob if scaled < threshold else prob / 2
    return total
# arm1 has win probability p1
# arm2 has win probability p2
# arm1 is better than arm2 (p1 > p2)
# arm1 was pulled n1 times
# arm2 was pulled n2 times
# What is the probability that arm2 got the better evaluation?
# (Equal evaluations count as a wrong choice half of the time.)
def wrongChoice(n1, p1, n2, p2):
    """Return the probability of picking the worse arm from sampled winrates.

    Arm 1 (win probability p1, pulled n1 times) is the truly better arm;
    arm 2 (win probability p2, pulled n2 times) is worse. The result is the
    probability that arm 2's empirical winrate beats arm 1's, with equal
    winrates counted as a wrong choice half of the time (via `worse`).

    Raises AssertionError if p1 is not strictly greater than p2.
    """
    assert p1 > p2, "wrongChoice expects arm 1 to be the better arm (p1 > p2)"
    # Condition on arm 2's outcome k: P(arm 2 wins k of n2) times
    # P(arm 1's winrate is worse than k/n2). The builtin sum adds the terms
    # in the same order as an explicit loop, without shadowing `sum`.
    return sum(binomial(k, n2, p2) * worse(p1, n1, k, n2)
               for k in range(n2 + 1))
def test(p1, p2, maxN, printall = True):
    """Print wrong-choice probabilities when both arms get n simulations.

    Prints a header with p1 and p2, then one line per reported n:
    every n from 1 to maxN when printall is true, otherwise only
    n = 1, 2, 4, 8, ... up to maxN.

    Only the reported n are evaluated. Skipping the unreported n avoids
    computing results that would be thrown away.
    """
    print("p1 = {}, p2 = {}".format(p1, p2))
    n = 1
    while n <= maxN:
        result = wrongChoice(n, p1, n, p2)
        print("Both have {} simulations. Prob. of wrong arm choice {}"
              .format(n, result))
        # Step linearly, or double to sample on a log scale.
        n = n + 1 if printall else n * 2
def _run_demos():
    """Run every demo arm pair: first n = 1..10, then n = 1, 2, 4, ..., 32."""
    arm_pairs = ((0.9, 0.2), (0.5, 0.4), (0.5, 0.49), (0.5, 0.49999))
    for max_n, print_all in ((10, True), (32, False)):
        for better_p, worse_p in arm_pairs:
            test(p1=better_p, p2=worse_p, maxN=max_n, printall=print_all)


_run_demos()