Skip to content

Commit 1910bb5

Browse files
author
Martin Durant
committed
better working river example
1 parent 91f8e63 commit 1910bb5

1 file changed

Lines changed: 26 additions & 15 deletions

File tree

examples/river_kmeans.py

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,18 @@
1414
import tornado.ioloop
1515
hv.extension('bokeh')
1616

17-
model = cluster.STREAMKMeans(n_clusters=3)
17+
model = cluster.KMeans(n_clusters=3, sigma=0.1, mu=0.5)
1818
centres = [[random.random(), random.random()] for _ in range(3)]
19+
count = [0]
1920

20-
def gen(split_chance=0.01):
21+
def gen(move_chance=0.05):
2122
centre = int(random.random() * 3) # 3x faster than random.randint(0, 2)
22-
if random.random() < split_chance:
23-
centres[centre] = [random.random(), random.random()]
23+
if random.random() < move_chance:
24+
centres[centre][0] += random.random() / 5 - 0.1
25+
centres[centre][1] += random.random() / 5 - 0.1
2426
value = {'x': random.random() / 20 + centres[centre][0],
2527
'y': random.random() / 20 + centres[centre][1]}
28+
count[0] += 1
2629
return value
2730

2831

@@ -32,9 +35,10 @@ def get_clusters(model):
3235
return pd.DataFrame(data, index=range(3))
3336

3437

35-
def main():
38+
def main(viz=True):
3639
# setup pipes
37-
s = Stream.from_periodic(gen, 0.05)
40+
cadance = 0.16 if viz else 0.01
41+
s = Stream.from_periodic(gen, cadance)
3842
km = RiverTrain(model, pass_model=True)
3943
s.map(lambda x: (x,)).connect(km) # learn takes a tuple of (x,[ y[, w]])
4044
ex = pd.DataFrame({'x': [0.5], 'y': [0.5]})
@@ -47,16 +51,23 @@ def main():
4751
model.centers[i]['x'] = x
4852
model.centers[i]['y'] = y
4953

54+
print("starting")
5055
s.start()
5156

52-
# plot
53-
pout = out.to_dataframe(example=ex)
54-
pl = (ooo.hvplot.scatter('x', 'y', color="blue", backlog=100) *
55-
pout.hvplot.scatter('x', 'y', color="red", backlog=3))
56-
pl.opts(xlim=(-0.5, 1.5), ylim=(-0.5, 1.5), height=600, width=600)
57-
pan = panel.pane.holoviews.HoloViews(pl)
58-
pan.show()
59-
57+
if viz:
58+
# plot
59+
pout = out.to_dataframe(example=ex)
60+
pl = (ooo.hvplot.scatter('x', 'y', color="blue", backlog=50) *
61+
pout.hvplot.scatter('x', 'y', color="red", backlog=3))
62+
pl.opts(xlim=(-0.2, 1.2), ylim=(-0.2, 1.2), height=600, width=600)
63+
pan = panel.pane.holoviews.HoloViews(pl)
64+
pan.show(threaded=True)
65+
else:
66+
import time
67+
time.sleep(5)
68+
print(count, "events")
69+
print("Current centres", centres)
70+
print("Output centres", [list(c.values()) for c in model.centers.values()])
6071

6172
if __name__ == "__main__":
62-
main()
73+
main(viz=True)

0 commit comments

Comments
 (0)