Skip to content

Commit 3203377

Browse files
authored
Merge pull request #2 from GreedyAIAcademy/master
update
2 parents e32df33 + 2476118 commit 3203377

12 files changed

+107706
-0
lines changed

10. TopicModels/LDA_tutorial.ipynb

Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"![](attachment:image.png)"
8+
]
9+
},
10+
{
11+
"cell_type": "markdown",
12+
"metadata": {},
13+
"source": [
14+
"![](attachment:image.png)"
15+
]
16+
},
17+
{
18+
"cell_type": "markdown",
19+
"metadata": {},
20+
"source": []
21+
},
22+
{
23+
"cell_type": "markdown",
24+
"metadata": {},
25+
"source": [
26+
"### Gibbs sampler for LDA"
27+
]
28+
},
29+
{
30+
"cell_type": "code",
31+
"execution_count": 2,
32+
"metadata": {},
33+
"outputs": [],
34+
"source": [
35+
"# W := vocabulary (the distinct word ids)\n",
36+
"import numpy as np\n",
37+
"W = np.array([0, 1, 2, 3, 4])\n",
38+
"\n",
39+
"# X := document words; X[i, l] is the vocabulary id of the l-th word of document i\n",
40+
"X = np.array([\n",
41+
" [0, 0, 1, 2, 2],\n",
42+
" [0, 0, 1, 1, 1],\n",
43+
" [0, 1, 2, 2, 2],\n",
44+
" [4, 4, 4, 4, 4],\n",
45+
" [3, 3, 4, 4, 4],\n",
46+
" [3, 4, 4, 4, 4]\n",
47+
"])\n",
48+
"\n",
49+
"N_D = X.shape[0] # num of docs\n",
50+
"N_V = W.shape[0] # vocabulary size (num of distinct word ids, not words per doc)\n",
51+
"N_K = 2 # num of topics"
52+
]
53+
},
54+
{
55+
"cell_type": "code",
56+
"execution_count": 5,
57+
"metadata": {},
58+
"outputs": [],
59+
"source": [
60+
"import numpy as np\n",
61+
"# Dirichlet priors\n",
62+
"alpha = 1\n",
63+
"gamma = 1\n",
64+
"\n",
65+
"# Z := word topic assignment; Z[i, l] is the topic of the l-th word of document i\n",
66+
"Z = np.zeros(shape=X.shape, dtype=int)  # one slot per word position, not per vocabulary entry\n",
67+
"\n",
68+
"for i in range(N_D):\n",
69+
"    for l in range(X.shape[1]):\n",
70+
"        Z[i, l] = np.random.randint(N_K) # randomly assign word's topic\n",
71+
"\n",
72+
"# theta := document topic distribution\n",
73+
"theta = np.zeros([N_D, N_K])\n",
74+
"\n",
75+
"for i in range(N_D):\n",
76+
"    theta[i] = np.random.dirichlet(alpha*np.ones(N_K))\n",
77+
"\n",
78+
"# phi := topic word distribution; phi[k] is topic k's distribution over the vocabulary\n",
79+
"phi = np.zeros([N_K, N_V])\n",
80+
"\n",
81+
"for k in range(N_K):\n",
82+
"    phi[k] = np.random.dirichlet(gamma*np.ones(N_V))\n",
83+
"\n",
84+
"for it in range(1000):\n",
85+
"    # Sample from full conditional of Z\n",
86+
"    # ---------------------------------\n",
87+
"    for i in range(N_D):\n",
88+
"        for l in range(X.shape[1]):  # word positions in the document, not vocabulary ids\n",
89+
"            # p(Z[i, l] = k | theta, phi, X) is proportional to theta[i, k] * phi[k, X[i, l]]\n",
90+
"            p_il = theta[i] * phi[:, X[i, l]]\n",
91+
"            p_il /= np.sum(p_il)\n",
92+
"\n",
93+
"            # Resample word topic assignment Z\n",
94+
"            Z[i, l] = np.random.multinomial(1, p_il).argmax()\n",
95+
"\n",
96+
"    # Sample from full conditional of \\theta\n",
97+
"    # ----------------------------------\n",
98+
"    for i in range(N_D):\n",
99+
"        m = np.zeros(N_K)\n",
100+
"\n",
101+
"        # Gather sufficient statistics: m[k] = number of words in doc i assigned to topic k\n",
102+
"        for k in range(N_K):\n",
103+
"            m[k] = np.sum(Z[i] == k)\n",
104+
"\n",
105+
"        # Resample doc topic dist.\n",
106+
"        theta[i, :] = np.random.dirichlet(alpha + m)\n",
107+
"\n",
108+
"    # Sample from full conditional of \\phi\n",
109+
"    # ---------------------------------\n",
110+
"    for k in range(N_K):\n",
111+
"        n = np.zeros(N_V)\n",
112+
"\n",
113+
"        # Gather sufficient statistics: n[v] counts the tokens of vocabulary\n",
114+
"        # word v, over all documents and word positions, that are currently\n",
115+
"        # assigned to topic k.\n",
116+
"        for v in range(N_V):\n",
117+
"            n[v] = np.sum((X == v) & (Z == k))\n",
118+
"\n",
119+
"        # Resample word topic dist.\n",
120+
"        phi[k, :] = np.random.dirichlet(gamma + n)"
121+
]
122+
},
123+
{
124+
"cell_type": "code",
125+
"execution_count": 6,
126+
"metadata": {},
127+
"outputs": [
128+
{
129+
"name": "stdout",
130+
"output_type": "stream",
131+
"text": [
132+
"[[0.33941922 0.66058078]\n",
133+
" [0.12011984 0.87988016]\n",
134+
" [0.02318953 0.97681047]\n",
135+
" [0.80301107 0.19698893]\n",
136+
" [0.94780545 0.05219455]\n",
137+
" [0.84677554 0.15322446]]\n"
138+
]
139+
}
140+
],
141+
"source": [
142+
"print (theta)"
143+
]
144+
},
145+
{
146+
"cell_type": "code",
147+
"execution_count": null,
148+
"metadata": {},
149+
"outputs": [],
150+
"source": []
151+
},
152+
{
153+
"cell_type": "markdown",
154+
"metadata": {},
155+
"source": [
156+
"### Mini-homework - Collapsed Gibbs Sampler for LDA\n",
157+
"自己重新推导一下Collapsed Gibbs Sampler,并实现一下LDA类并利用上面给定的YELP数据集来测试一下,并跟sklearn的结果比较一下。"
158+
]
159+
},
160+
{
161+
"cell_type": "code",
162+
"execution_count": 7,
163+
"metadata": {},
164+
"outputs": [
165+
{
166+
"ename": "IndentationError",
167+
"evalue": "expected an indented block (<ipython-input-7-fe62b0541667>, line 7)",
168+
"output_type": "error",
169+
"traceback": [
170+
"\u001b[0;36m File \u001b[0;32m\"<ipython-input-7-fe62b0541667>\"\u001b[0;36m, line \u001b[0;32m7\u001b[0m\n\u001b[0;31m def fit(self, X, y=None):\u001b[0m\n\u001b[0m ^\u001b[0m\n\u001b[0;31mIndentationError\u001b[0m\u001b[0;31m:\u001b[0m expected an indented block\n"
171+
]
172+
}
173+
],
174+
"source": [
175+
"class LDA:\n",
176+
"    \"\"\"Latent Dirichlet allocation using collapsed Gibbs sampling\"\"\"\n",
177+
"\n",
178+
"    def __init__(self):\n",
179+
"        pass  # TODO(homework): accept and store the model hyperparameters\n",
180+
"\n",
181+
"    def fit(self, X, y=None):\n",
182+
"        \"\"\"Fit the model with X.\"\"\"\n",
183+
"        raise NotImplementedError  # TODO(homework)\n",
184+
"\n",
185+
"    def fit_transform(self, X, y=None):\n",
186+
"        raise NotImplementedError  # TODO(homework)\n",
187+
"\n",
188+
"    def transform(self, X, max_iter=20, tol=1e-16):\n",
189+
"        raise NotImplementedError  # TODO(homework)\n",
190+
"\n",
191+
"    def loglikelihood(self):\n",
192+
"        \"\"\"Calculate complete log likelihood, log p(w,z)\n",
193+
"        Formula used is log p(w,z) = log p(w|z) + log p(z)\"\"\"\n",
194+
"        raise NotImplementedError  # TODO(homework)\n",
195+
"    def perplexity(self):\n",
196+
"        \"\"\"Calculate the perplexity\"\"\"\n",
197+
"        raise NotImplementedError  # TODO(homework)"
198+
]
199+
},
200+
{
201+
"cell_type": "code",
202+
"execution_count": null,
203+
"metadata": {},
204+
"outputs": [],
205+
"source": []
206+
}
207+
],
208+
"metadata": {
209+
"kernelspec": {
210+
"display_name": "Python 3",
211+
"language": "python",
212+
"name": "python3"
213+
},
214+
"language_info": {
215+
"codemirror_mode": {
216+
"name": "ipython",
217+
"version": 3
218+
},
219+
"file_extension": ".py",
220+
"mimetype": "text/x-python",
221+
"name": "python",
222+
"nbconvert_exporter": "python",
223+
"pygments_lexer": "ipython3",
224+
"version": "3.7.1"
225+
}
226+
},
227+
"nbformat": 4,
228+
"nbformat_minor": 2
229+
}

0 commit comments

Comments
 (0)