Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion lib/config/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,6 @@ func (cfg *Config) Clone() *Config {
}

func (cfg *Config) Check() error {

if cfg.Workdir == "" {
d, err := os.Getwd()
if err != nil {
Expand Down
9 changes: 5 additions & 4 deletions pkg/manager/config/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,16 +136,17 @@ func (e *ConfigManager) Close() error {
e.cancel()
e.cancel = nil
}
if e.wch != nil {
wcherr = e.wch.Close()
e.wch = nil
}
e.sts.Lock()
for _, ch := range e.sts.listeners {
close(ch)
}
e.sts.listeners = nil
e.sts.Unlock()
e.wg.Wait()
// close after all goroutines are done
if e.wch != nil {
wcherr = e.wch.Close()
e.wch = nil
}
return wcherr
}
3 changes: 2 additions & 1 deletion pkg/manager/infosync/info.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,8 @@ func (is *InfoSyncer) getTopologyInfo(cfg *config.Config) (*TopologyInfo, error)
s = ""
}
dir := path.Dir(s)
ip, port, err := net.SplitHostPort(cfg.Proxy.Addr)
addrs := strings.Split(cfg.Proxy.Addr, ",")
ip, port, err := net.SplitHostPort(addrs[0])
if err != nil {
return nil, errors.WithStack(err)
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/metrics/grafana/tiproxy_summary.json
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@
"expr": "tiproxy_server_connections{k8s_cluster=\"$k8s_cluster\", tidb_cluster=\"$tidb_cluster\", instance=~\"$instance\"}",
"format": "time_series",
"intervalFactor": 2,
"legendFormat": "{{instance}}",
"legendFormat": "{{instance}} | {{addr}}",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need to adapt to that: Gateway doesn't use TiProxy metrics.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops, I just saw the comment #393 (comment)

"refId": "A"
},
{
Expand Down
2 changes: 1 addition & 1 deletion pkg/metrics/grafana/tiproxy_summary.jsonnet
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ local connectionP = graphPanel.new(
.addTarget(
prometheus.target(
'tiproxy_server_connections{k8s_cluster="$k8s_cluster", tidb_cluster="$tidb_cluster", instance=~"$instance"}',
legendFormat='{{instance}}',
legendFormat='{{instance}} | {{addr}}',
)
)
.addTarget(
Expand Down
5 changes: 3 additions & 2 deletions pkg/metrics/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,20 @@ import (

const (
LblType = "type"
LblAddr = "addr"

EventStart = "start"
EventClose = "close"
)

var (
ConnGauge = prometheus.NewGauge(
ConnGauge = prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Namespace: ModuleProxy,
Subsystem: LabelServer,
Name: "connections",
Help: "Number of connections.",
})
}, []string{LblAddr})

MaxProcsGauge = prometheus.NewGauge(
prometheus.GaugeOpts{
Expand Down
1 change: 1 addition & 0 deletions pkg/proxy/backend/handshake_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ type ConnContextKey string
const (
ConnContextKeyTLSState ConnContextKey = "tls-state"
ConnContextKeyConnID ConnContextKey = "conn-id"
ConnContextKeyConnAddr ConnContextKey = "conn-addr"
)

type ErrorSource int
Expand Down
3 changes: 2 additions & 1 deletion pkg/proxy/client/client_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ type ClientConnection struct {
}

func NewClientConnection(logger *zap.Logger, conn net.Conn, frontendTLSConfig *tls.Config, backendTLSConfig *tls.Config,
hsHandler backend.HandshakeHandler, connID uint64, bcConfig *backend.BCConfig) *ClientConnection {
hsHandler backend.HandshakeHandler, connID uint64, addr string, bcConfig *backend.BCConfig) *ClientConnection {
bemgr := backend.NewBackendConnManager(logger.Named("be"), hsHandler, connID, bcConfig)
bemgr.SetValue(backend.ConnContextKeyConnAddr, addr)
opts := make([]pnet.PacketIOption, 0, 2)
opts = append(opts, pnet.WithWrapError(backend.ErrClientConn))
if bcConfig.ProxyProtocol {
Expand Down
69 changes: 40 additions & 29 deletions pkg/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ package proxy

import (
"context"
"fmt"
"net"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -37,7 +39,8 @@ type serverState struct {
}

type SQLServer struct {
listener net.Listener
listeners []net.Listener
addrs []string
logger *zap.Logger
certMgr *cert.CertManager
hsHandler backend.HandshakeHandler
Expand Down Expand Up @@ -65,9 +68,14 @@ func NewSQLServer(logger *zap.Logger, cfg config.ProxyServer, certMgr *cert.Cert

s.reset(&cfg.ProxyServerOnline)

s.listener, err = net.Listen("tcp", cfg.Addr)
if err != nil {
return nil, err
s.addrs = strings.Split(cfg.Addr, ",")
fmt.Printf("xhe %s\n", s.addrs)
s.listeners = make([]net.Listener, len(s.addrs))
for i, addr := range s.addrs {
s.listeners[i], err = net.Listen("tcp", addr)
if err != nil {
return nil, err
}
}

return s, nil
Expand Down Expand Up @@ -104,31 +112,34 @@ func (s *SQLServer) Run(ctx context.Context, cfgch <-chan *config.Config) {
}
})

s.wg.Run(func() {
for {
select {
case <-ctx.Done():
return
default:
conn, err := s.listener.Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
return
for i := range s.listeners {
j := i
s.wg.Run(func() {
for {
select {
case <-ctx.Done():
return
default:
conn, err := s.listeners[j].Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
return
}

s.logger.Error("accept failed", zap.Error(err))
continue
}

s.logger.Error("accept failed", zap.Error(err))
continue
s.wg.Run(func() {
util.WithRecovery(func() { s.onConn(ctx, conn, s.addrs[j]) }, nil, s.logger)
})
}

s.wg.Run(func() {
util.WithRecovery(func() { s.onConn(ctx, conn) }, nil, s.logger)
})
}
}
})
})
}
}

func (s *SQLServer) onConn(ctx context.Context, conn net.Conn) {
func (s *SQLServer) onConn(ctx context.Context, conn net.Conn, addr string) {
s.mu.Lock()
conns := uint64(len(s.mu.clients))
maxConns := s.mu.maxConnections
Expand All @@ -149,9 +160,9 @@ func (s *SQLServer) onConn(ctx context.Context, conn net.Conn) {
connID := s.mu.connID
s.mu.connID++
logger := s.logger.With(zap.Uint64("connID", connID), zap.String("client_addr", conn.RemoteAddr().String()),
zap.Bool("proxy-protocol", s.mu.proxyProtocol))
zap.Bool("proxy-protocol", s.mu.proxyProtocol), zap.String("addr", addr))
clientConn := client.NewClientConnection(logger.Named("conn"), conn, s.certMgr.ServerTLS(), s.certMgr.SQLTLS(),
s.hsHandler, connID, &backend.BCConfig{
s.hsHandler, connID, addr, &backend.BCConfig{
ProxyProtocol: s.mu.proxyProtocol,
RequireBackendTLS: s.requireBackendTLS,
HealthyKeepAlive: s.mu.healthyKeepAlive,
Expand All @@ -162,7 +173,7 @@ func (s *SQLServer) onConn(ctx context.Context, conn net.Conn) {
s.mu.Unlock()

logger.Info("new connection")
metrics.ConnGauge.Inc()
metrics.ConnGauge.WithLabelValues(addr).Inc()

defer func() {
s.mu.Lock()
Expand All @@ -174,7 +185,7 @@ func (s *SQLServer) onConn(ctx context.Context, conn net.Conn) {
} else {
logger.Info("connection closed")
}
metrics.ConnGauge.Dec()
metrics.ConnGauge.WithLabelValues(addr).Dec()
}()

if err := keepalive.SetKeepalive(conn, config.KeepAlive{Enabled: tcpKeepAlive}); err != nil {
Expand Down Expand Up @@ -232,8 +243,8 @@ func (s *SQLServer) Close() error {
s.cancelFunc = nil
}
errs := make([]error, 0, 4)
if s.listener != nil {
errs = append(errs, s.listener.Close())
for i := range s.listeners {
errs = append(errs, s.listeners[i].Close())
}

s.mu.RLock()
Expand Down
37 changes: 30 additions & 7 deletions pkg/proxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package proxy
import (
"context"
"database/sql"
"fmt"
"net"
"strings"
"testing"
Expand Down Expand Up @@ -48,13 +49,13 @@ func TestGracefulShutdown(t *testing.T) {
createClientConn := func() *client.ClientConnection {
server.mu.Lock()
go func() {
conn, err := net.Dial("tcp", server.listener.Addr().String())
conn, err := net.Dial("tcp", server.listeners[0].Addr().String())
require.NoError(t, err)
require.NoError(t, conn.Close())
}()
conn, err := server.listener.Accept()
conn, err := server.listeners[0].Accept()
require.NoError(t, err)
clientConn := client.NewClientConnection(lg, conn, nil, nil, hsHandler, 0, &backend.BCConfig{})
clientConn := client.NewClientConnection(lg, conn, nil, nil, hsHandler, 0, "", &backend.BCConfig{})
server.mu.clients[1] = clientConn
server.mu.Unlock()
return clientConn
Expand Down Expand Up @@ -107,18 +108,40 @@ func TestGracefulShutdown(t *testing.T) {
}
}

func TestRecoverPanic(t *testing.T) {
lg, text := logger.CreateLoggerForTest(t)
func TestMultiAddr(t *testing.T) {
lg, _ := logger.CreateLoggerForTest(t)
certManager := cert.NewCertManager()
err := certManager.Init(&config.Config{}, lg, nil)
require.NoError(t, err)
server, err := NewSQLServer(lg, config.ProxyServer{
Addr: "0.0.0.0:6000",
Addr: "0.0.0.0:0,0.0.0.0:0",
}, certManager, &panicHsHandler{})
require.NoError(t, err)
server.Run(context.Background(), nil)

mdb, err := sql.Open("mysql", "root@tcp(localhost:6000)/test")
require.Len(t, server.listeners, 2)
for _, listener := range server.listeners {
conn, err := net.Dial("tcp", listener.Addr().String())
require.NoError(t, err)
require.NoError(t, conn.Close())
}

require.NoError(t, server.Close())
certManager.Close()
}

func TestRecoverPanic(t *testing.T) {
lg, text := logger.CreateLoggerForTest(t)
certManager := cert.NewCertManager()
err := certManager.Init(&config.Config{}, lg, nil)
require.NoError(t, err)
server, err := NewSQLServer(lg, config.ProxyServer{}, certManager, &panicHsHandler{})
require.NoError(t, err)
server.Run(context.Background(), nil)

_, port, err := net.SplitHostPort(server.listeners[0].Addr().String())
require.NoError(t, err)
mdb, err := sql.Open("mysql", fmt.Sprintf("root@tcp(localhost:%s)/test", port))
require.NoError(t, err)
// The first connection encounters panic.
require.ErrorContains(t, mdb.Ping(), "invalid connection")
Expand Down