Skip to content

Commit 4f877f4

Browse files
author
Martin Durant
committed
Fix notebook example
1 parent 1910bb5 commit 4f877f4

2 files changed

Lines changed: 26 additions & 29 deletions

File tree

examples/river_kmeans.ipynb

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,13 @@
1010
"import random\n",
1111
"\n",
1212
"import pandas as pd\n",
13-
"import tornado.ioloop\n",
1413
"\n",
1514
"from streamz import Stream\n",
1615
"import hvplot.streamz\n",
1716
"from streamz.river import RiverTrain\n",
1817
"from river import cluster\n",
1918
"import holoviews as hv\n",
20-
"import panel as pn\n",
2119
"from panel.pane.holoviews import HoloViews\n",
22-
"import tornado.ioloop\n",
2320
"hv.extension('bokeh')"
2421
]
2522
},
@@ -30,13 +27,14 @@
3027
"metadata": {},
3128
"outputs": [],
3229
"source": [
33-
"model = cluster.STREAMKMeans(n_clusters=3)\n",
30+
"model = cluster.KMeans(n_clusters=3, sigma=0.1, mu=0.5)\n",
3431
"centres = [[random.random(), random.random()] for _ in range(3)]\n",
3532
"\n",
36-
"def gen(split_chance=0.01):\n",
33+
"def gen(move_chance=0.05):\n",
3734
" centre = int(random.random() * 3) # 3x faster than random.randint(0, 2)\n",
38-
" if random.random() < split_chance:\n",
39-
" centres[centre] = [random.random(), random.random()]\n",
35+
" if random.random() < move_chance:\n",
36+
" centres[centre][0] += random.random() / 5 - 0.1\n",
37+
" centres[centre][1] += random.random() / 5 - 0.1\n",
4038
" value = {'x': random.random() / 20 + centres[centre][0],\n",
4139
" 'y': random.random() / 20 + centres[centre][1]}\n",
4240
" return value\n",
@@ -55,31 +53,38 @@
5553
"metadata": {},
5654
"outputs": [],
5755
"source": [
58-
"s = Stream.from_periodic(gen, 0.05)\n",
56+
"s = Stream.from_periodic(gen, 0.03)\n",
5957
"km = RiverTrain(model, pass_model=True)\n",
6058
"s.map(lambda x: (x,)).connect(km) # learn takes a tuple of (x,[ y[, w]])\n",
6159
"ex = pd.DataFrame({'x': [0.5], 'y': [0.5]})\n",
6260
"ooo = s.map(lambda x: pd.DataFrame([x])).to_dataframe(example=ex)\n",
6361
"out = km.map(get_clusters)\n",
64-
"s.emit(gen()) # set initial model"
62+
"\n",
63+
"# start things\n",
64+
"s.emit(gen()) # set initial model\n",
65+
"for i, (x, y) in enumerate(centres):\n",
66+
" model.centers[i]['x'] = x\n",
67+
" model.centers[i]['y'] = y\n"
6568
]
6669
},
6770
{
6871
"cell_type": "code",
6972
"execution_count": null,
70-
"id": "c24d2363",
73+
"id": "1b4de451",
7174
"metadata": {},
7275
"outputs": [],
7376
"source": [
74-
"for i, (x, y) in enumerate(centres):\n",
75-
" model.centers[i]['x'] = x\n",
76-
" model.centers[i]['y'] = y\n"
77+
"pout = out.to_dataframe(example=ex)\n",
78+
"pl = (ooo.hvplot.scatter('x', 'y', color=\"blue\", backlog=50) *\n",
79+
" pout.hvplot.scatter('x', 'y', color=\"red\", backlog=3))\n",
80+
"pl.opts(xlim=(-0.2, 1.2), ylim=(-0.2, 1.2), height=600, width=600)\n",
81+
"pl"
7782
]
7883
},
7984
{
8085
"cell_type": "code",
8186
"execution_count": null,
82-
"id": "1b4de451",
87+
"id": "c24d2363",
8388
"metadata": {},
8489
"outputs": [],
8590
"source": [
@@ -89,25 +94,20 @@
8994
{
9095
"cell_type": "code",
9196
"execution_count": null,
92-
"id": "8f356afa",
97+
"id": "18cfd94e",
9398
"metadata": {},
9499
"outputs": [],
95100
"source": [
96-
"pout = out.to_dataframe(example=ex)\n",
97-
"pl = (ooo.hvplot.scatter('x', 'y', color=\"blue\", backlog=100) * \n",
98-
" pout.hvplot.scatter('x', 'y', color=\"red\", backlog=3))\n",
99-
"pl.opts(xlim=(-0.5, 1.5), ylim=(-0.5, 1.5), height=600, width=600)"
101+
"s.stop()"
100102
]
101103
},
102104
{
103105
"cell_type": "code",
104106
"execution_count": null,
105-
"id": "18cfd94e",
107+
"id": "4537495c",
106108
"metadata": {},
107109
"outputs": [],
108-
"source": [
109-
"s.stop()"
110-
]
110+
"source": []
111111
}
112112
],
113113
"metadata": {

examples/river_kmeans.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,13 @@
11
import random
22

33
import pandas as pd
4-
import panel.pane.holoviews
5-
import tornado.ioloop
64

75
from streamz import Stream
86
import hvplot.streamz
97
from streamz.river import RiverTrain
108
from river import cluster
119
import holoviews as hv
12-
import panel as pn
1310
from panel.pane.holoviews import HoloViews
14-
import tornado.ioloop
1511
hv.extension('bokeh')
1612

1713
model = cluster.KMeans(n_clusters=3, sigma=0.1, mu=0.5)
@@ -60,14 +56,15 @@ def main(viz=True):
6056
pl = (ooo.hvplot.scatter('x', 'y', color="blue", backlog=50) *
6157
pout.hvplot.scatter('x', 'y', color="red", backlog=3))
6258
pl.opts(xlim=(-0.2, 1.2), ylim=(-0.2, 1.2), height=600, width=600)
63-
pan = panel.pane.holoviews.HoloViews(pl)
64-
pan.show(threaded=True)
59+
pan = HoloViews(pl)
60+
pan.show()
6561
else:
6662
import time
6763
time.sleep(5)
6864
print(count, "events")
6965
print("Current centres", centres)
7066
print("Output centres", [list(c.values()) for c in model.centers.values()])
67+
s.stop()
7168

7269
if __name__ == "__main__":
7370
main(viz=True)

0 commit comments

Comments
 (0)