1414import tornado .ioloop
1515hv .extension ('bokeh' )
1616
17- model = cluster .STREAMKMeans (n_clusters = 3 )
17+ model = cluster .KMeans (n_clusters = 3 , sigma = 0.1 , mu = 0.5 )
1818centres = [[random .random (), random .random ()] for _ in range (3 )]
19+ count = [0 ]
1920
20- def gen (split_chance = 0.01 ):
21+ def gen (move_chance = 0.05 ):
2122 centre = int (random .random () * 3 ) # 3x faster than random.randint(0, 2)
22- if random .random () < split_chance :
23- centres [centre ] = [random .random (), random .random ()]
23+ if random .random () < move_chance :
24+ centres [centre ][0 ] += random .random () / 5 - 0.1
25+ centres [centre ][1 ] += random .random () / 5 - 0.1
2426 value = {'x' : random .random () / 20 + centres [centre ][0 ],
2527 'y' : random .random () / 20 + centres [centre ][1 ]}
28+ count [0 ] += 1
2629 return value
2730
2831
@@ -32,9 +35,10 @@ def get_clusters(model):
3235 return pd .DataFrame (data , index = range (3 ))
3336
3437
35- def main ():
38+ def main (viz = True ):
3639 # setup pipes
37- s = Stream .from_periodic (gen , 0.05 )
40+ cadance = 0.16 if viz else 0.01
41+ s = Stream .from_periodic (gen , cadance )
3842 km = RiverTrain (model , pass_model = True )
3943 s .map (lambda x : (x ,)).connect (km ) # learn takes a tuple of (x,[ y[, w]])
4044 ex = pd .DataFrame ({'x' : [0.5 ], 'y' : [0.5 ]})
@@ -47,16 +51,23 @@ def main():
4751 model .centers [i ]['x' ] = x
4852 model .centers [i ]['y' ] = y
4953
54+ print ("starting" )
5055 s .start ()
5156
52- # plot
53- pout = out .to_dataframe (example = ex )
54- pl = (ooo .hvplot .scatter ('x' , 'y' , color = "blue" , backlog = 100 ) *
55- pout .hvplot .scatter ('x' , 'y' , color = "red" , backlog = 3 ))
56- pl .opts (xlim = (- 0.5 , 1.5 ), ylim = (- 0.5 , 1.5 ), height = 600 , width = 600 )
57- pan = panel .pane .holoviews .HoloViews (pl )
58- pan .show ()
59-
57+ if viz :
58+ # plot
59+ pout = out .to_dataframe (example = ex )
60+ pl = (ooo .hvplot .scatter ('x' , 'y' , color = "blue" , backlog = 50 ) *
61+ pout .hvplot .scatter ('x' , 'y' , color = "red" , backlog = 3 ))
62+ pl .opts (xlim = (- 0.2 , 1.2 ), ylim = (- 0.2 , 1.2 ), height = 600 , width = 600 )
63+ pan = panel .pane .holoviews .HoloViews (pl )
64+ pan .show (threaded = True )
65+ else :
66+ import time
67+ time .sleep (5 )
68+ print (count , "events" )
69+ print ("Current centres" , centres )
70+ print ("Output centres" , [list (c .values ()) for c in model .centers .values ()])
6071
6172if __name__ == "__main__" :
62- main ()
73+ main (viz = True )
0 commit comments