forked from graykode/distribution-is-all-you-need
-
Notifications
You must be signed in to change notification settings - Fork 0
/
dirichlet.py
63 lines (54 loc) · 1.61 KB
/
dirichlet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
"""
Code by Tae-Hwan Hung(@graykode)
https://en.wikipedia.org/wiki/Dirichlet_distribution
3-Class Example
"""
from random import randint
import numpy as np
from matplotlib import pyplot as plt
def normalization(x, s):
"""
:return: normalizated list, where sum(x) == s
"""
return [(i * s) / sum(x) for i in x]
def sampling():
return normalization([randint(1, 100),
randint(1, 100), randint(1, 100)], s=1)
def gamma_function(n):
cal = 1
for i in range(2, n):
cal *= i
return cal
def beta_function(alpha):
"""
:param alpha: list, len(alpha) is k
:return:
"""
numerator = 1
for a in alpha:
numerator *= gamma_function(a)
denominator = gamma_function(sum(alpha))
return numerator / denominator
def dirichlet(x, a, n):
"""
:param x: list of [x[1,...,K], x[1,...,K], ...], shape is (n_trial, K)
:param a: list of coefficient, a_i > 0
:param n: number of trial
:return:
"""
c = (1 / beta_function(a))
y = [c * (xn[0] ** (a[0] - 1)) * (xn[1] ** (a[1] - 1))
* (xn[2] ** (a[2] - 1)) for xn in x]
x = np.arange(n)
return x, y, np.mean(y), np.std(y)
n_experiment = 1200
for ls in [(6, 2, 2), (3, 7, 5), (6, 2, 6), (2, 3, 4)]:
alpha = list(ls)
# random samping [x[1,...,K], x[1,...,K], ...], shape is (n_trial, K)
# each sum of row should be one.
x = [sampling() for _ in range(1, n_experiment + 1)]
x, y, u, s = dirichlet(x, alpha, n=n_experiment)
plt.plot(x, y, label=r'$\alpha=(%d,%d,%d)$' % (ls[0], ls[1], ls[2]))
plt.legend()
plt.savefig('graph/dirichlet.png')
plt.show()