Skip to content

Commit 15dc75e

Browse files
committed
added multivariate model
1 parent 4b37aab commit 15dc75e

1 file changed

Lines changed: 156 additions & 12 deletions

File tree

notebooks/bonus-exploration-finches.ipynb

Lines changed: 156 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -78,19 +78,17 @@
7878
" nu = pm.Exponential('nu', lam=1/29.) + 1\n",
7979
" \n",
8080
" # Define the likelihood distribution for the data.\n",
81-
" depth = pm.StudentT('depth', \n",
81+
" depth = pm.StudentT('beak_depth', \n",
8282
" nu=nu,\n",
8383
" mu=mean_depth[df['species_enc']], \n",
8484
" sd=sd_depth[df['species_enc']], \n",
8585
" observed=df['beak_depth'])\n",
8686
" \n",
87-
" length = pm.StudentT('length',\n",
87+
" length = pm.StudentT('beak_length',\n",
8888
" nu=nu,\n",
8989
" mu=mean_length[df['species_enc']],\n",
9090
" sd=sd_length[df['species_enc']],\n",
91-
" observed=df['beak_length'])\n",
92-
" \n",
93-
" shape = pm.Deterministic('shape', depth / length)"
91+
" observed=df['beak_length'])"
9492
]
9593
},
9694
{
@@ -131,6 +129,73 @@
131129
"samples"
132130
]
133131
},
132+
{
133+
"cell_type": "markdown",
134+
"metadata": {},
135+
"source": [
136+
"PPC check for Fortis"
137+
]
138+
},
139+
{
140+
"cell_type": "code",
141+
"execution_count": null,
142+
"metadata": {},
143+
"outputs": [],
144+
"source": [
145+
"fig = plt.figure()\n",
146+
"ax1 = fig.add_subplot(121)\n",
147+
"ax2 = fig.add_subplot(122, sharex=ax1)\n",
148+
"\n",
149+
"def plot_ppc_data(samples, df, idxs, column, ax):\n",
150+
" x, y = ECDF(samples[column][:, idxs].flatten())\n",
151+
" ax.plot(x, y, label='ppc')\n",
152+
" x, y = ECDF(df.iloc[idxs][column])\n",
153+
" ax.plot(x, y, label='data')\n",
154+
" ax.set_xlabel(column)\n",
155+
" ax.set_ylabel('cumulative fraction')\n",
156+
" return ax\n",
157+
"\n",
158+
"ax1 = plot_ppc_data(samples, df, fortis_idx, 'beak_depth', ax1)\n",
159+
"ax2 = plot_ppc_data(samples, df, fortis_idx, 'beak_length', ax2)\n",
160+
"\n",
161+
"fig.suptitle('Fortis')\n",
162+
"plt.tight_layout()"
163+
]
164+
},
165+
{
166+
"cell_type": "markdown",
167+
"metadata": {},
168+
"source": [
169+
"PPC check for Scandens"
170+
]
171+
},
172+
{
173+
"cell_type": "code",
174+
"execution_count": null,
175+
"metadata": {},
176+
"outputs": [],
177+
"source": [
178+
"fig = plt.figure()\n",
179+
"ax1 = fig.add_subplot(121)\n",
180+
"ax2 = fig.add_subplot(122, sharex=ax1)\n",
181+
"\n",
182+
"ax1 = plot_ppc_data(samples, df, scandens_idx, 'beak_depth', ax1)\n",
183+
"ax2 = plot_ppc_data(samples, df, scandens_idx, 'beak_length', ax2)\n",
184+
"\n",
185+
"fig.suptitle('Scandens')\n",
186+
"plt.tight_layout()"
187+
]
188+
},
189+
{
190+
"cell_type": "code",
191+
"execution_count": null,
192+
"metadata": {},
193+
"outputs": [],
194+
"source": [
195+
"plt.hist(samples['beak_depth'].flatten())\n",
196+
"plt.hist(samples['beak_length'].flatten())"
197+
]
198+
},
134199
{
135200
"cell_type": "code",
136201
"execution_count": null,
@@ -139,7 +204,7 @@
139204
"source": [
140205
"fig = plt.figure()\n",
141206
"ax = fig.add_subplot(111)\n",
142-
"x, y = ECDF((samples['depth'][:, fortis_idx] / samples['length'][:, fortis_idx]).flatten())\n",
207+
"x, y = ECDF((samples['beak_depth'][:, fortis_idx] / samples['beak_length'][:, fortis_idx]).flatten())\n",
143208
"ax.plot(x, y)\n",
144209
"x, y = ECDF(df.loc[fortis_idx, 'shape'])\n",
145210
"ax.plot(x, y)"
@@ -151,12 +216,91 @@
151216
"metadata": {},
152217
"outputs": [],
153218
"source": [
154-
"fig = plt.figure()\n",
155-
"ax = fig.add_subplot(111)\n",
156-
"x, y = ECDF(df['shape'])\n",
157-
"ax.plot(x, y, label='data')\n",
158-
"# x, y = ECDF(trace['shape'][0, :])\n",
159-
"# ax.plot(x, y, label='posterior')\n"
219+
"fig = plt.figure(figsize=(12, 4))\n",
220+
"\n",
221+
"def plot_length_depth_scatter(df, idxs, title, ax):\n",
222+
" ax.scatter(df.iloc[idxs]['beak_length'], df.iloc[idxs]['beak_depth'])\n",
223+
" ax.set_xlabel('beak_length')\n",
224+
" ax.set_ylabel('beak_depth')\n",
225+
" ax.set_title(title)\n",
226+
" return ax\n",
227+
"\n",
228+
"\n",
229+
"ax1 = fig.add_subplot(121)\n",
230+
"ax1 = plot_length_depth_scatter(df, scandens_idx, 'scandens', ax1)\n",
231+
"\n",
232+
"ax2 = fig.add_subplot(122, sharex=ax1, sharey=ax1)\n",
233+
"ax2 = plot_length_depth_scatter(df, fortis_idx, 'fortis', ax2)\n"
234+
]
235+
},
236+
{
237+
"cell_type": "markdown",
238+
"metadata": {},
239+
"source": [
240+
"Going to try a new model: we explicity model depth and length jointly, as a multivariate gaussian."
241+
]
242+
},
243+
{
244+
"cell_type": "code",
245+
"execution_count": null,
246+
"metadata": {},
247+
"outputs": [],
248+
"source": [
249+
"with pm.Model() as mv_beaks: # multivariate beak model\n",
250+
" packed_L = pm.LKJCholeskyCov('packed_L', n=2,\n",
251+
" eta=2., sd_dist=pm.HalfCauchy.dist(2.5))\n",
252+
" L = pm.expand_packed_triangular(2, packed_L)\n",
253+
" sigma = pm.Deterministic('sigma', L.dot(L.T))\n",
254+
"\n",
255+
" mu = pm.HalfNormal('mu', sd=20, shape=(2,))\n",
256+
" \n",
257+
" like = pm.MvNormal('like', mu=mu, cov=sigma, observed=df[['beak_depth', 'beak_length']].values)"
258+
]
259+
},
260+
{
261+
"cell_type": "code",
262+
"execution_count": null,
263+
"metadata": {},
264+
"outputs": [],
265+
"source": [
266+
"with mv_beaks:\n",
267+
" trace_mv = pm.sample(2000, njobs=1)"
268+
]
269+
},
270+
{
271+
"cell_type": "code",
272+
"execution_count": null,
273+
"metadata": {},
274+
"outputs": [],
275+
"source": [
276+
"pm.traceplot(trace_mv)"
277+
]
278+
},
279+
{
280+
"cell_type": "code",
281+
"execution_count": null,
282+
"metadata": {},
283+
"outputs": [],
284+
"source": [
285+
"pm.forestplot(trace_mv, varnames=['sigma'])"
286+
]
287+
},
288+
{
289+
"cell_type": "code",
290+
"execution_count": null,
291+
"metadata": {},
292+
"outputs": [],
293+
"source": [
294+
"pm.forestplot(trace_mv, varnames=['mu'])"
295+
]
296+
},
297+
{
298+
"cell_type": "code",
299+
"execution_count": null,
300+
"metadata": {},
301+
"outputs": [],
302+
"source": [
303+
"samples_mv = pm.sample_ppc(trace, model=mv_beaks)"
160304
]
161305
},
162306
{

0 commit comments

Comments
 (0)