Skip to content

Commit 6b6a130

Browse files
committed
refactor: use argument struct train actions
1 parent 7ad9f1c commit 6b6a130

2 files changed

Lines changed: 57 additions & 66 deletions

File tree

cmd/ba/main.go

Lines changed: 18 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -97,35 +97,24 @@ func makeapp() *cli.App {
9797
cli.IntFlag{Name: "duration", Usage: "If set, game will stop after this durarion (in seconds)"},
9898
},
9999
Action: func(c *cli.Context) error {
100-
tps := c.Int("tps")
101-
host := c.String("host")
102-
agents := c.StringSlice("agent")
103-
port := c.Int("port")
104-
vizhost := c.String("viz-host")
105-
recordFile := c.String("record-file")
106-
mapName := c.String("map")
107-
nobrowser := c.Bool("no-browser")
108-
isDebug := c.Bool("debug")
109-
isQuiet := c.Bool("quiet")
110-
shouldProfile := c.Bool("profile")
111-
dumpRaw := c.Bool("dump-raw-comm")
112-
duration := c.Int("duration")
113-
114-
showUsage, err := train.TrainAction(
115-
tps,
116-
host,
117-
port,
118-
vizhost,
119-
nobrowser,
120-
recordFile,
121-
agents,
122-
isDebug,
123-
isQuiet,
124-
mapName,
125-
shouldProfile,
126-
dumpRaw,
127-
duration,
128-
)
100+
101+
args := train.TrainActionArguments{
102+
Tps: c.Int("tps"),
103+
Host: c.String("host"),
104+
Agentimages: c.StringSlice("agent"),
105+
Vizport: c.Int("port"),
106+
Vizhost: c.String("viz-host"),
107+
RecordFile: c.String("record-file"),
108+
MapName: c.String("map"),
109+
Nobrowser: c.Bool("no-browser"),
110+
IsDebug: c.Bool("debug"),
111+
IsQuiet: c.Bool("quiet"),
112+
ShouldProfile: c.Bool("profile"),
113+
DumpRaw: c.Bool("dump-raw-comm"),
114+
DurationSeconds: c.Int("duration"),
115+
}
116+
117+
showUsage, err := train.TrainAction(args)
129118

130119
if err != nil {
131120
commandFailWith("train", showUsage, c, err)

subcommand/train/main.go

Lines changed: 39 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -51,23 +51,25 @@ const (
5151
TIME_BEFORE_FORCE_QUIT = 5 * time.Second
5252
)
5353

54-
func TrainAction(
55-
tps int,
56-
host string,
57-
vizport int,
58-
vizhost string,
59-
nobrowser bool,
60-
recordFile string,
61-
agentimages []string,
62-
isDebug bool,
63-
isQuiet bool,
64-
mapName string,
65-
shouldProfile,
66-
dumpRaw bool,
67-
durationSeconds int,
68-
) (bool, error) {
69-
70-
if shouldProfile {
54+
type TrainActionArguments struct {
55+
Tps int
56+
Host string
57+
Vizport int
58+
Vizhost string
59+
Nobrowser bool
60+
RecordFile string
61+
Agentimages []string
62+
IsDebug bool
63+
IsQuiet bool
64+
MapName string
65+
ShouldProfile bool
66+
DumpRaw bool
67+
DurationSeconds int
68+
}
69+
70+
func TrainAction(args TrainActionArguments) (bool, error) {
71+
72+
if args.ShouldProfile {
7173
f, err := os.Create("./cpu.prof")
7274
if err != nil {
7375
log.Fatal("could not create CPU profile: ", err)
@@ -80,27 +82,27 @@ func TrainAction(
8082

8183
var gameDuration *time.Duration
8284

83-
if durationSeconds > 0 {
84-
d := time.Duration(durationSeconds) * time.Second
85+
if args.DurationSeconds > 0 {
86+
d := time.Duration(args.DurationSeconds) * time.Second
8587
gameDuration = &d
8688
}
8789

8890
shutdownChan := make(chan bool)
8991
debug := func(str string) {}
9092

91-
if isDebug {
93+
if args.IsDebug {
9294
debug = func(str string) {
9395
fmt.Printf(DebugColor("[debug] %s\n"), str)
9496
}
9597
}
9698

97-
if host == "" {
99+
if args.Host == "" {
98100
ip, err := utils.GetCurrentIP()
99101
utils.Check(err, "Could not determine host IP; you can specify using the `--host` flag.")
100-
host = ip
102+
args.Host = ip
101103
}
102104

103-
if len(agentimages) == 0 {
105+
if len(args.Agentimages) == 0 {
104106
return SHOW_USAGE, bettererrors.New("No agents were specified")
105107
}
106108

@@ -110,24 +112,24 @@ func TrainAction(
110112
brokerclient, err := NewMemoryMessageClient()
111113
utils.Check(err, "ERROR: Could not connect to messagebroker")
112114

113-
mappack, errMappack := mappack.UnzipAndGetHandles(mapcmd.GetMapLocation(mapName))
115+
mappack, errMappack := mappack.UnzipAndGetHandles(mapcmd.GetMapLocation(args.MapName))
114116
if errMappack != nil {
115117
utils.FailWith(errMappack)
116118
}
117119

118-
gamedescription, err := NewMockGame(tps, mappack)
120+
gamedescription, err := NewMockGame(args.Tps, mappack)
119121
if err != nil {
120122
utils.FailWith(err)
121123
}
122124

123125
game := deathmatch.NewDeathmatchGame(gamedescription)
124126

125-
orchestrator := container.MakeLocalContainerOrchestrator(host)
127+
orchestrator := container.MakeLocalContainerOrchestrator(args.Host)
126128

127129
arenaServerUUID := ""
128130

129131
srv := arenaserver.NewServer(
130-
host,
132+
args.Host,
131133
orchestrator,
132134
gamedescription,
133135
game,
@@ -136,7 +138,7 @@ func TrainAction(
136138
gameDuration,
137139
)
138140

139-
for _, dockerImageName := range agentimages {
141+
for _, dockerImageName := range args.Agentimages {
140142
agentManifest, err := types.GetAgentManifestByDockerImageName(dockerImageName, orchestrator)
141143
if err != nil {
142144
return DONT_SHOW_USAGE, err
@@ -157,15 +159,15 @@ func TrainAction(
157159

158160
switch t := msg.(type) {
159161
case arenaserver.EventStatusGameUpdate:
160-
if !isQuiet {
162+
if !args.IsQuiet {
161163
fmt.Printf(GameColor("[game] %s\n"), t.Status)
162164
}
163165

164166
case arenaserver.EventAgentLog:
165167
fmt.Printf(AgentColor("[agent] %s\n"), t.Value)
166168

167169
case arenaserver.EventLog:
168-
if !isQuiet {
170+
if !args.IsQuiet {
169171
fmt.Printf(LogColor("[log] %s\n"), t.Value)
170172
}
171173

@@ -182,7 +184,7 @@ func TrainAction(
182184
fmt.Printf(HeadsUpColor("[headsup] %s\n"), t.Value)
183185

184186
case arenaserver.EventRawComm:
185-
if dumpRaw {
187+
if args.DumpRaw {
186188
fmt.Printf(AgentColor("[agent] %s\n"), t.Value)
187189
}
188190

@@ -211,8 +213,8 @@ func TrainAction(
211213
go common.StreamState(srv, brokerclient, "trainer")
212214

213215
var recorder recording.RecorderInterface = recording.MakeEmptyRecorder()
214-
if recordFile != "" {
215-
recorder = recording.MakeSingleArenaRecorder(recordFile)
216+
if args.RecordFile != "" {
217+
recorder = recording.MakeSingleArenaRecorder(args.RecordFile)
216218
}
217219

218220
recorder.RecordMetadata(gamedescription.GetId(), gamedescription.GetMapContainer())
@@ -230,8 +232,8 @@ func TrainAction(
230232
vizgames[0] = viztypes.NewVizGame(game, gamedescription)
231233

232234
vizservice := visualization.NewVizService(
233-
vizhost+":"+strconv.Itoa(vizport),
234-
mapName,
235+
args.Vizhost+":"+strconv.Itoa(args.Vizport),
236+
args.MapName,
235237
func() ([]*viztypes.VizGame, error) { return vizgames, nil },
236238
recorder,
237239
mappack,
@@ -245,9 +247,9 @@ func TrainAction(
245247
utils.FailWith(startErr)
246248
}
247249

248-
url := "http://" + vizhost + ":" + strconv.Itoa(vizport) + "/arena/1"
250+
url := "http://" + args.Vizhost + ":" + strconv.Itoa(args.Vizport) + "/arena/1"
249251

250-
if !nobrowser {
252+
if !args.Nobrowser {
251253
open.Run(url)
252254
}
253255

0 commit comments

Comments
 (0)