Skip to content

Commit a6b26de

Browse files
committed
use partition to refresh the plot less often
1 parent 0b63090 commit a6b26de

1 file changed

Lines changed: 73 additions & 11 deletions

File tree

examples/river_kmeans.py

Lines changed: 73 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import functools
12
import random
3+
import time
24

35
import pandas as pd
46

@@ -8,6 +10,7 @@
810
from river import cluster
911
import holoviews as hv
1012
from panel.pane.holoviews import HoloViews
13+
import panel as pn
1114
hv.extension('bokeh')
1215

1316
model = cluster.KMeans(n_clusters=3, sigma=0.1, mu=0.5)
@@ -32,14 +35,49 @@ def get_clusters(model):
3235

3336

3437
def main(viz=True):
38+
cadance = 0.01
39+
40+
ex = pd.DataFrame({'x': [0.5], 'y': [0.5]})
41+
pipe_in = hv.streams.Pipe(data=ex)
42+
pipe_out = hv.streams.Pipe(data=ex)
43+
3544
# setup pipes
36-
cadance = 0.16 if viz else 0.01
3745
s = Stream.from_periodic(gen, cadance)
46+
47+
# Branch 0: Input/Observations
48+
obs = s.map(lambda x: pd.DataFrame([x]))
49+
50+
# Branch 1: Output/River ML clusters
3851
km = RiverTrain(model, pass_model=True)
3952
s.map(lambda x: (x,)).connect(km) # learn takes a tuple of (x,[ y[, w]])
40-
ex = pd.DataFrame({'x': [0.5], 'y': [0.5]})
41-
ooo = s.map(lambda x: pd.DataFrame([x])).to_dataframe(example=ex)
42-
out = km.map(get_clusters)
53+
clusters = km.map(get_clusters)
54+
55+
concat = functools.partial(pd.concat, ignore_index=True)
56+
57+
def accumulate(previous, new, last_lines=50):
58+
return concat([previous, new]).iloc[-last_lines:, :]
59+
60+
partition_obs = 10
61+
particion_clusters = 10
62+
backlog_obs = 100
63+
64+
# .partition is used to gather x number of points
65+
# before sending them to the plots
66+
# .accumulate allows to generate a backlog
67+
68+
(
69+
obs
70+
.partition(partition_obs)
71+
.map(concat)
72+
.accumulate(functools.partial(accumulate, last_lines=backlog_obs))
73+
.sink(pipe_in.send)
74+
)
75+
(
76+
clusters
77+
.partition(particion_clusters)
78+
.map(pd.concat)
79+
.sink(pipe_out.send)
80+
)
4381

4482
# start things
4583
s.emit(gen()) # set initial model
@@ -48,23 +86,47 @@ def main(viz=True):
4886
model.centers[i]['y'] = y
4987

5088
print("starting")
51-
s.start()
5289

5390
if viz:
5491
# plot
55-
pout = out.to_dataframe(example=ex)
56-
pl = (ooo.hvplot.scatter('x', 'y', color="blue", backlog=50) *
57-
pout.hvplot.scatter('x', 'y', color="red", backlog=3))
92+
button_start = pn.widgets.Button(name='Start')
93+
button_stop = pn.widgets.Button(name='Stop')
94+
95+
t0 = 0
96+
97+
def start(event):
98+
s.start()
99+
global t0
100+
t0 = time.time()
101+
102+
def stop(event):
103+
print(count, "events")
104+
global t0
105+
t_spent = time.time() - t0
106+
print("frequency", count[0] / t_spent, "Hz")
107+
print("Current centres", centres)
108+
print("Output centres", [list(c.values()) for c in model.centers.values()])
109+
s.stop()
110+
111+
button_start.on_click(start)
112+
button_stop.on_click(stop)
113+
114+
scatter_dmap_input = hv.DynamicMap(hv.Scatter, streams=[pipe_in]).opts(color="blue")
115+
scatter_dmap_output = hv.DynamicMap(hv.Scatter, streams=[pipe_out]).opts(color="red")
116+
pl = scatter_dmap_input * scatter_dmap_output
58117
pl.opts(xlim=(-0.2, 1.2), ylim=(-0.2, 1.2), height=600, width=600)
118+
59119
pan = HoloViews(pl)
60-
pan.show()
120+
app = pn.Row(pn.Column(button_start, button_stop), pan)
121+
app.show()
61122
else:
62-
import time
123+
s.start()
63124
time.sleep(5)
64125
print(count, "events")
126+
print("frequency", count[0] / 5, "Hz")
65127
print("Current centres", centres)
66128
print("Output centres", [list(c.values()) for c in model.centers.values()])
67-
s.stop()
129+
s.stop()
68130

69131
if __name__ == "__main__":
70132
main(viz=True)

0 commit comments

Comments
 (0)