Skip to content

Commit bf61e48

Browse files
committed
cmd/ursrv: Refactor to use CLI options, fewer global vars
1 parent b2886f1 commit bf61e48

File tree

1 file changed

+82
-67
lines changed

1 file changed

+82
-67
lines changed

cmd/ursrv/main.go

Lines changed: 82 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525
"time"
2626
"unicode"
2727

28+
"github.com/alecthomas/kong"
2829
_ "github.com/lib/pq" // PostgreSQL driver
2930
"github.com/oschwald/geoip2-golang"
3031
"golang.org/x/text/cases"
@@ -34,14 +35,17 @@ import (
3435
"github.com/syncthing/syncthing/lib/ur/contract"
3536
)
3637

38+
type CLI struct {
39+
UseHTTP bool `env:"UR_USE_HTTP"`
40+
Debug bool `env:"UR_DEBUG"`
41+
KeyFile string `env:"UR_KEY_FILE" default:"key.pem"`
42+
CertFile string `env:"UR_CRT_FILE" default:"crt.pem"`
43+
DBConn string `env:"UR_DB_URL" default:"postgres://user:password@localhost/ur?sslmode=disable"`
44+
Listen string `env:"UR_LISTEN" default:"0.0.0.0:8443"`
45+
GeoIPPath string `env:"UR_GEOIP" default:"GeoLite2-City.mmdb"`
46+
}
47+
3748
var (
38-
useHTTP = os.Getenv("UR_USE_HTTP") != ""
39-
debug = os.Getenv("UR_DEBUG") != ""
40-
keyFile = getEnvDefault("UR_KEY_FILE", "key.pem")
41-
certFile = getEnvDefault("UR_CRT_FILE", "crt.pem")
42-
dbConn = getEnvDefault("UR_DB_URL", "postgres://user:password@localhost/ur?sslmode=disable")
43-
listenAddr = getEnvDefault("UR_LISTEN", "0.0.0.0:8443")
44-
geoIPPath = getEnvDefault("UR_GEOIP", "GeoLite2-City.mmdb")
4549
tpl *template.Template
4650
compilerRe = regexp.MustCompile(`\(([A-Za-z0-9()., -]+) \w+-\w+(?:| android| default)\) ([\[email protected]]+)`)
4751
progressBarClass = []string{"", "progress-bar-success", "progress-bar-info", "progress-bar-warning", "progress-bar-danger"}
@@ -159,6 +163,9 @@ func main() {
159163
log.SetFlags(log.Ltime | log.Ldate | log.Lshortfile)
160164
log.SetOutput(os.Stdout)
161165

166+
var cli CLI
167+
kong.Parse(&cli)
168+
162169
// Template
163170

164171
fd, err := os.Open("static/index.html")
@@ -174,7 +181,7 @@ func main() {
174181

175182
// DB
176183

177-
db, err := sql.Open("postgres", dbConn)
184+
db, err := sql.Open("postgres", cli.DBConn)
178185
if err != nil {
179186
log.Fatalln("database:", err)
180187
}
@@ -186,11 +193,11 @@ func main() {
186193
// TLS & Listening
187194

188195
var listener net.Listener
189-
if useHTTP {
190-
listener, err = net.Listen("tcp", listenAddr)
196+
if cli.UseHTTP {
197+
listener, err = net.Listen("tcp", cli.Listen)
191198
} else {
192199
var cert tls.Certificate
193-
cert, err = tls.LoadX509KeyPair(certFile, keyFile)
200+
cert, err = tls.LoadX509KeyPair(cli.CertFile, cli.KeyFile)
194201
if err != nil {
195202
log.Fatalln("tls:", err)
196203
}
@@ -199,112 +206,120 @@ func main() {
199206
Certificates: []tls.Certificate{cert},
200207
SessionTicketsDisabled: true,
201208
}
202-
listener, err = tls.Listen("tcp", listenAddr, cfg)
209+
listener, err = tls.Listen("tcp", cli.Listen, cfg)
203210
}
204211
if err != nil {
205212
log.Fatalln("listen:", err)
206213
}
207214

208-
srv := http.Server{
209-
ReadTimeout: 5 * time.Second,
210-
WriteTimeout: 15 * time.Second,
215+
srv := &server{
216+
db: db,
217+
debug: cli.Debug,
218+
geoIPPath: cli.GeoIPPath,
211219
}
212-
213-
http.HandleFunc("/", withDB(db, rootHandler))
214-
http.HandleFunc("/newdata", withDB(db, newDataHandler))
215-
http.HandleFunc("/summary.json", withDB(db, summaryHandler))
216-
http.HandleFunc("/movement.json", withDB(db, movementHandler))
217-
http.HandleFunc("/performance.json", withDB(db, performanceHandler))
218-
http.HandleFunc("/blockstats.json", withDB(db, blockStatsHandler))
219-
http.HandleFunc("/locations.json", withDB(db, locationsHandler))
220+
http.HandleFunc("/", srv.rootHandler)
221+
http.HandleFunc("/newdata", srv.newDataHandler)
222+
http.HandleFunc("/summary.json", srv.summaryHandler)
223+
http.HandleFunc("/movement.json", srv.movementHandler)
224+
http.HandleFunc("/performance.json", srv.performanceHandler)
225+
http.HandleFunc("/blockstats.json", srv.blockStatsHandler)
226+
http.HandleFunc("/locations.json", srv.locationsHandler)
220227
http.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.Dir("static"))))
221228

222-
go cacheRefresher(db)
229+
go srv.cacheRefresher()
223230

224-
err = srv.Serve(listener)
231+
httpSrv := http.Server{
232+
ReadTimeout: 5 * time.Second,
233+
WriteTimeout: 15 * time.Second,
234+
}
235+
err = httpSrv.Serve(listener)
225236
if err != nil {
226237
log.Fatalln("https:", err)
227238
}
228239
}
229240

230-
var (
241+
type server struct {
242+
debug bool
243+
db *sql.DB
244+
geoIPPath string
245+
246+
cacheMut sync.Mutex
231247
cachedIndex []byte
232248
cachedLocations []byte
233249
cacheTime time.Time
234-
cacheMut sync.Mutex
235-
)
250+
}
236251

237252
const maxCacheTime = 15 * time.Minute
238253

239-
func cacheRefresher(db *sql.DB) {
254+
func (s *server) cacheRefresher() {
240255
ticker := time.NewTicker(maxCacheTime - time.Minute)
241256
defer ticker.Stop()
242257
for ; true; <-ticker.C {
243-
cacheMut.Lock()
244-
if err := refreshCacheLocked(db); err != nil {
258+
s.cacheMut.Lock()
259+
if err := s.refreshCacheLocked(); err != nil {
245260
log.Println(err)
246261
}
247-
cacheMut.Unlock()
262+
s.cacheMut.Unlock()
248263
}
249264
}
250265

251-
func refreshCacheLocked(db *sql.DB) error {
252-
rep := getReport(db)
266+
func (s *server) refreshCacheLocked() error {
267+
rep := getReport(s.db, s.geoIPPath)
253268
buf := new(bytes.Buffer)
254269
err := tpl.Execute(buf, rep)
255270
if err != nil {
256271
return err
257272
}
258-
cachedIndex = buf.Bytes()
259-
cacheTime = time.Now()
273+
s.cachedIndex = buf.Bytes()
274+
s.cacheTime = time.Now()
260275

261276
locs := rep["locations"].(map[location]int)
262277
wlocs := make([]weightedLocation, 0, len(locs))
263278
for loc, w := range locs {
264279
wlocs = append(wlocs, weightedLocation{loc, w})
265280
}
266-
cachedLocations, _ = json.Marshal(wlocs)
281+
s.cachedLocations, _ = json.Marshal(wlocs)
267282
return nil
268283
}
269284

270-
func rootHandler(db *sql.DB, w http.ResponseWriter, r *http.Request) {
285+
func (s *server) rootHandler(w http.ResponseWriter, r *http.Request) {
271286
if r.URL.Path == "/" || r.URL.Path == "/index.html" {
272-
cacheMut.Lock()
273-
defer cacheMut.Unlock()
287+
s.cacheMut.Lock()
288+
defer s.cacheMut.Unlock()
274289

275-
if time.Since(cacheTime) > maxCacheTime {
276-
if err := refreshCacheLocked(db); err != nil {
290+
if time.Since(s.cacheTime) > maxCacheTime {
291+
if err := s.refreshCacheLocked(); err != nil {
277292
log.Println(err)
278293
http.Error(w, "Template Error", http.StatusInternalServerError)
279294
return
280295
}
281296
}
282297

283298
w.Header().Set("Content-Type", "text/html; charset=utf-8")
284-
w.Write(cachedIndex)
299+
w.Write(s.cachedIndex)
285300
} else {
286301
http.Error(w, "Not found", 404)
287302
return
288303
}
289304
}
290305

291-
func locationsHandler(db *sql.DB, w http.ResponseWriter, _ *http.Request) {
292-
cacheMut.Lock()
293-
defer cacheMut.Unlock()
306+
func (s *server) locationsHandler(w http.ResponseWriter, _ *http.Request) {
307+
s.cacheMut.Lock()
308+
defer s.cacheMut.Unlock()
294309

295-
if time.Since(cacheTime) > maxCacheTime {
296-
if err := refreshCacheLocked(db); err != nil {
310+
if time.Since(s.cacheTime) > maxCacheTime {
311+
if err := s.refreshCacheLocked(); err != nil {
297312
log.Println(err)
298313
http.Error(w, "Template Error", http.StatusInternalServerError)
299314
return
300315
}
301316
}
302317

303318
w.Header().Set("Content-Type", "application/json; charset=utf-8")
304-
w.Write(cachedLocations)
319+
w.Write(s.cachedLocations)
305320
}
306321

307-
func newDataHandler(db *sql.DB, w http.ResponseWriter, r *http.Request) {
322+
func (s *server) newDataHandler(w http.ResponseWriter, r *http.Request) {
308323
defer r.Body.Close()
309324

310325
addr := r.Header.Get("X-Forwarded-For")
@@ -330,7 +345,7 @@ func newDataHandler(db *sql.DB, w http.ResponseWriter, r *http.Request) {
330345
bs, _ := io.ReadAll(lr)
331346
if err := json.Unmarshal(bs, &rep); err != nil {
332347
log.Println("decode:", err)
333-
if debug {
348+
if s.debug {
334349
log.Printf("%s", bs)
335350
}
336351
http.Error(w, "JSON Decode Error", http.StatusInternalServerError)
@@ -339,38 +354,38 @@ func newDataHandler(db *sql.DB, w http.ResponseWriter, r *http.Request) {
339354

340355
if err := rep.Validate(); err != nil {
341356
log.Println("validate:", err)
342-
if debug {
357+
if s.debug {
343358
log.Printf("%#v", rep)
344359
}
345360
http.Error(w, "Validation Error", http.StatusInternalServerError)
346361
return
347362
}
348363

349-
if err := insertReport(db, rep); err != nil {
364+
if err := insertReport(s.db, rep); err != nil {
350365
if err.Error() == `pq: duplicate key value violates unique constraint "uniqueidjsonindex"` {
351366
// We already have a report today for the same unique ID; drop
352367
// this one without complaining.
353368
return
354369
}
355370
log.Println("insert:", err)
356-
if debug {
371+
if s.debug {
357372
log.Printf("%#v", rep)
358373
}
359374
http.Error(w, "Database Error", http.StatusInternalServerError)
360375
return
361376
}
362377
}
363378

364-
func summaryHandler(db *sql.DB, w http.ResponseWriter, r *http.Request) {
379+
func (s *server) summaryHandler(w http.ResponseWriter, r *http.Request) {
365380
min, _ := strconv.Atoi(r.URL.Query().Get("min"))
366-
s, err := getSummary(db, min)
381+
sum, err := getSummary(s.db, min)
367382
if err != nil {
368383
log.Println("summaryHandler:", err)
369384
http.Error(w, "Database Error", http.StatusInternalServerError)
370385
return
371386
}
372387

373-
bs, err := s.MarshalJSON()
388+
bs, err := sum.MarshalJSON()
374389
if err != nil {
375390
log.Println("summaryHandler:", err)
376391
http.Error(w, "JSON Encode Error", http.StatusInternalServerError)
@@ -381,15 +396,15 @@ func summaryHandler(db *sql.DB, w http.ResponseWriter, r *http.Request) {
381396
w.Write(bs)
382397
}
383398

384-
func movementHandler(db *sql.DB, w http.ResponseWriter, _ *http.Request) {
385-
s, err := getMovement(db)
399+
func (s *server) movementHandler(w http.ResponseWriter, _ *http.Request) {
400+
mov, err := getMovement(s.db)
386401
if err != nil {
387402
log.Println("movementHandler:", err)
388403
http.Error(w, "Database Error", http.StatusInternalServerError)
389404
return
390405
}
391406

392-
bs, err := json.Marshal(s)
407+
bs, err := json.Marshal(mov)
393408
if err != nil {
394409
log.Println("movementHandler:", err)
395410
http.Error(w, "JSON Encode Error", http.StatusInternalServerError)
@@ -400,15 +415,15 @@ func movementHandler(db *sql.DB, w http.ResponseWriter, _ *http.Request) {
400415
w.Write(bs)
401416
}
402417

403-
func performanceHandler(db *sql.DB, w http.ResponseWriter, _ *http.Request) {
404-
s, err := getPerformance(db)
418+
func (s *server) performanceHandler(w http.ResponseWriter, _ *http.Request) {
419+
perf, err := getPerformance(s.db)
405420
if err != nil {
406421
log.Println("performanceHandler:", err)
407422
http.Error(w, "Database Error", http.StatusInternalServerError)
408423
return
409424
}
410425

411-
bs, err := json.Marshal(s)
426+
bs, err := json.Marshal(perf)
412427
if err != nil {
413428
log.Println("performanceHandler:", err)
414429
http.Error(w, "JSON Encode Error", http.StatusInternalServerError)
@@ -419,15 +434,15 @@ func performanceHandler(db *sql.DB, w http.ResponseWriter, _ *http.Request) {
419434
w.Write(bs)
420435
}
421436

422-
func blockStatsHandler(db *sql.DB, w http.ResponseWriter, _ *http.Request) {
423-
s, err := getBlockStats(db)
437+
func (s *server) blockStatsHandler(w http.ResponseWriter, _ *http.Request) {
438+
blocks, err := getBlockStats(s.db)
424439
if err != nil {
425440
log.Println("blockStatsHandler:", err)
426441
http.Error(w, "Database Error", http.StatusInternalServerError)
427442
return
428443
}
429444

430-
bs, err := json.Marshal(s)
445+
bs, err := json.Marshal(blocks)
431446
if err != nil {
432447
log.Println("blockStatsHandler:", err)
433448
http.Error(w, "JSON Encode Error", http.StatusInternalServerError)
@@ -513,7 +528,7 @@ type weightedLocation struct {
513528
Weight int `json:"weight"`
514529
}
515530

516-
func getReport(db *sql.DB) map[string]interface{} {
531+
func getReport(db *sql.DB, geoIPPath string) map[string]interface{} {
517532
geoip, err := geoip2.Open(geoIPPath)
518533
if err != nil {
519534
log.Println("opening geoip db", err)

0 commit comments

Comments
 (0)