Skip to content

Commit 9bc1824

Browse files
committed
Support lastMessage() and rowsAffected()
1 parent f20b286 commit 9bc1824

12 files changed

+106
-77
lines changed

auth.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -225,15 +225,15 @@ func encryptPassword(password string, seed []byte, pub *rsa.PublicKey) ([]byte,
225225
return rsa.EncryptOAEP(sha1, rand.Reader, pub, plain, nil)
226226
}
227227

228-
func (mc *mysqlConn) sendEncryptedPassword(seed []byte, pub *rsa.PublicKey) error {
228+
func (mc *MysqlConn) sendEncryptedPassword(seed []byte, pub *rsa.PublicKey) error {
229229
enc, err := encryptPassword(mc.cfg.Passwd, seed, pub)
230230
if err != nil {
231231
return err
232232
}
233233
return mc.writeAuthSwitchPacket(enc)
234234
}
235235

236-
func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) {
236+
func (mc *MysqlConn) auth(authData []byte, plugin string) ([]byte, error) {
237237
switch plugin {
238238
case "caching_sha2_password":
239239
authResp := scrambleSHA256Password(authData, mc.cfg.Passwd)
@@ -296,7 +296,7 @@ func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) {
296296
}
297297
}
298298

299-
func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
299+
func (mc *MysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
300300
// Read Result Packet
301301
authData, newPlugin, err := mc.readAuthResult()
302302
if err != nil {

benchmark_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ func BenchmarkRoundtripBin(b *testing.B) {
215215
}
216216

217217
func BenchmarkInterpolation(b *testing.B) {
218-
mc := &mysqlConn{
218+
mc := &MysqlConn{
219219
cfg: &Config{
220220
InterpolateParams: true,
221221
Loc: time.UTC,

connection.go

Lines changed: 44 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,13 @@ import (
2020
"time"
2121
)
2222

23-
type mysqlConn struct {
23+
type MysqlConn struct {
2424
buf buffer
2525
netConn net.Conn
2626
rawConn net.Conn // underlying connection when netConn is TLS connection.
2727
affectedRows uint64
2828
insertId uint64
29+
lastMessage string
2930
cfg *Config
3031
maxAllowedPacket int
3132
maxWriteSize int
@@ -45,8 +46,18 @@ type mysqlConn struct {
4546
closed atomicBool // set when conn is closed, before closech is closed
4647
}
4748

49+
// RowsAffected returns the number of rows affected by the query.
50+
func (mc *MysqlConn) RowsAffected() uint64 {
51+
return mc.affectedRows
52+
}
53+
54+
// LastMessage returns the database's last message.
55+
func (mc *MysqlConn) LastMessage() string {
56+
return mc.lastMessage
57+
}
58+
4859
// Handles parameters set in DSN after the connection is established
49-
func (mc *mysqlConn) handleParams() (err error) {
60+
func (mc *MysqlConn) handleParams() (err error) {
5061
var cmdSet strings.Builder
5162
for param, val := range mc.cfg.Params {
5263
switch param {
@@ -89,7 +100,7 @@ func (mc *mysqlConn) handleParams() (err error) {
89100
return
90101
}
91102

92-
func (mc *mysqlConn) markBadConn(err error) error {
103+
func (mc *MysqlConn) markBadConn(err error) error {
93104
if mc == nil {
94105
return err
95106
}
@@ -99,11 +110,11 @@ func (mc *mysqlConn) markBadConn(err error) error {
99110
return driver.ErrBadConn
100111
}
101112

102-
func (mc *mysqlConn) Begin() (driver.Tx, error) {
113+
func (mc *MysqlConn) Begin() (driver.Tx, error) {
103114
return mc.begin(false)
104115
}
105116

106-
func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) {
117+
func (mc *MysqlConn) begin(readOnly bool) (driver.Tx, error) {
107118
if mc.closed.Load() {
108119
errLog.Print(ErrInvalidConn)
109120
return nil, driver.ErrBadConn
@@ -121,7 +132,7 @@ func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) {
121132
return nil, mc.markBadConn(err)
122133
}
123134

124-
func (mc *mysqlConn) Close() (err error) {
135+
func (mc *MysqlConn) Close() (err error) {
125136
// Makes Close idempotent
126137
if !mc.closed.Load() {
127138
err = mc.writeCommandPacket(comQuit)
@@ -136,7 +147,7 @@ func (mc *mysqlConn) Close() (err error) {
136147
// function after successfully authentication, call Close instead. This function
137148
// is called before auth or on auth failure because MySQL will have already
138149
// closed the network connection.
139-
func (mc *mysqlConn) cleanup() {
150+
func (mc *MysqlConn) cleanup() {
140151
if mc.closed.Swap(true) {
141152
return
142153
}
@@ -151,7 +162,7 @@ func (mc *mysqlConn) cleanup() {
151162
}
152163
}
153164

154-
func (mc *mysqlConn) error() error {
165+
func (mc *MysqlConn) error() error {
155166
if mc.closed.Load() {
156167
if err := mc.canceled.Value(); err != nil {
157168
return err
@@ -161,7 +172,7 @@ func (mc *mysqlConn) error() error {
161172
return nil
162173
}
163174

164-
func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
175+
func (mc *MysqlConn) Prepare(query string) (driver.Stmt, error) {
165176
if mc.closed.Load() {
166177
errLog.Print(ErrInvalidConn)
167178
return nil, driver.ErrBadConn
@@ -195,7 +206,7 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
195206
return stmt, err
196207
}
197208

198-
func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) {
209+
func (mc *MysqlConn) interpolateParams(query string, args []driver.Value) (string, error) {
199210
// Number of ? should be same to len(args)
200211
if strings.Count(query, "?") != len(args) {
201212
return "", driver.ErrSkip
@@ -294,7 +305,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
294305
return string(buf), nil
295306
}
296307

297-
func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
308+
func (mc *MysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
298309
if mc.closed.Load() {
299310
errLog.Print(ErrInvalidConn)
300311
return nil, driver.ErrBadConn
@@ -312,6 +323,7 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err
312323
}
313324
mc.affectedRows = 0
314325
mc.insertId = 0
326+
mc.lastMessage = ""
315327

316328
err := mc.exec(query)
317329
if err == nil {
@@ -324,7 +336,7 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err
324336
}
325337

326338
// Internal function to execute commands
327-
func (mc *mysqlConn) exec(query string) error {
339+
func (mc *MysqlConn) exec(query string) error {
328340
// Send command
329341
if err := mc.writeCommandPacketStr(comQuery, query); err != nil {
330342
return mc.markBadConn(err)
@@ -351,11 +363,11 @@ func (mc *mysqlConn) exec(query string) error {
351363
return mc.discardResults()
352364
}
353365

354-
func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
366+
func (mc *MysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
355367
return mc.query(query, args)
356368
}
357369

358-
func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) {
370+
func (mc *MysqlConn) query(query string, args []driver.Value) (*textRows, error) {
359371
if mc.closed.Load() {
360372
errLog.Print(ErrInvalidConn)
361373
return nil, driver.ErrBadConn
@@ -371,6 +383,11 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error)
371383
}
372384
query = prepared
373385
}
386+
387+
mc.affectedRows = 0
388+
mc.insertId = 0
389+
mc.lastMessage = ""
390+
374391
// Send command
375392
err := mc.writeCommandPacketStr(comQuery, query)
376393
if err == nil {
@@ -402,7 +419,7 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error)
402419

403420
// Gets the value of the given MySQL System Variable
404421
// The returned byte slice is only valid until the next read
405-
func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
422+
func (mc *MysqlConn) getSystemVar(name string) ([]byte, error) {
406423
// Send command
407424
if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil {
408425
return nil, err
@@ -431,13 +448,13 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
431448
}
432449

433450
// finish is called when the query has canceled.
434-
func (mc *mysqlConn) cancel(err error) {
451+
func (mc *MysqlConn) cancel(err error) {
435452
mc.canceled.Set(err)
436453
mc.cleanup()
437454
}
438455

439456
// finish is called when the query has succeeded.
440-
func (mc *mysqlConn) finish() {
457+
func (mc *MysqlConn) finish() {
441458
if !mc.watching || mc.finished == nil {
442459
return
443460
}
@@ -449,7 +466,7 @@ func (mc *mysqlConn) finish() {
449466
}
450467

451468
// Ping implements driver.Pinger interface
452-
func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
469+
func (mc *MysqlConn) Ping(ctx context.Context) (err error) {
453470
if mc.closed.Load() {
454471
errLog.Print(ErrInvalidConn)
455472
return driver.ErrBadConn
@@ -468,7 +485,7 @@ func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
468485
}
469486

470487
// BeginTx implements driver.ConnBeginTx interface
471-
func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
488+
func (mc *MysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
472489
if mc.closed.Load() {
473490
return nil, driver.ErrBadConn
474491
}
@@ -492,7 +509,7 @@ func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver
492509
return mc.begin(opts.ReadOnly)
493510
}
494511

495-
func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
512+
func (mc *MysqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
496513
dargs, err := namedValueToValue(args)
497514
if err != nil {
498515
return nil, err
@@ -511,7 +528,7 @@ func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driv
511528
return rows, err
512529
}
513530

514-
func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
531+
func (mc *MysqlConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
515532
dargs, err := namedValueToValue(args)
516533
if err != nil {
517534
return nil, err
@@ -525,7 +542,7 @@ func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []drive
525542
return mc.Exec(query, dargs)
526543
}
527544

528-
func (mc *mysqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
545+
func (mc *MysqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
529546
if err := mc.watchCancel(ctx); err != nil {
530547
return nil, err
531548
}
@@ -578,7 +595,7 @@ func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue
578595
return stmt.Exec(dargs)
579596
}
580597

581-
func (mc *mysqlConn) watchCancel(ctx context.Context) error {
598+
func (mc *MysqlConn) watchCancel(ctx context.Context) error {
582599
if mc.watching {
583600
// Reach here if canceled,
584601
// so the connection is already invalid
@@ -603,7 +620,7 @@ func (mc *mysqlConn) watchCancel(ctx context.Context) error {
603620
return nil
604621
}
605622

606-
func (mc *mysqlConn) startWatcher() {
623+
func (mc *MysqlConn) startWatcher() {
607624
watcher := make(chan context.Context, 1)
608625
mc.watcher = watcher
609626
finished := make(chan struct{})
@@ -628,14 +645,14 @@ func (mc *mysqlConn) startWatcher() {
628645
}()
629646
}
630647

631-
func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) {
648+
func (mc *MysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) {
632649
nv.Value, err = converter{}.ConvertValue(nv.Value)
633650
return
634651
}
635652

636653
// ResetSession implements driver.SessionResetter.
637654
// (From Go 1.10)
638-
func (mc *mysqlConn) ResetSession(ctx context.Context) error {
655+
func (mc *MysqlConn) ResetSession(ctx context.Context) error {
639656
if mc.closed.Load() {
640657
return driver.ErrBadConn
641658
}
@@ -645,6 +662,6 @@ func (mc *mysqlConn) ResetSession(ctx context.Context) error {
645662

646663
// IsValid implements driver.Validator interface
647664
// (From Go 1.15)
648-
func (mc *mysqlConn) IsValid() bool {
665+
func (mc *MysqlConn) IsValid() bool {
649666
return !mc.closed.Load()
650667
}

connection_test.go

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ import (
1818
)
1919

2020
func TestInterpolateParams(t *testing.T) {
21-
mc := &mysqlConn{
21+
mc := &MysqlConn{
2222
buf: newBuffer(nil),
2323
maxAllowedPacket: maxPacketSize,
2424
cfg: &Config{
@@ -38,7 +38,7 @@ func TestInterpolateParams(t *testing.T) {
3838
}
3939

4040
func TestInterpolateParamsJSONRawMessage(t *testing.T) {
41-
mc := &mysqlConn{
41+
mc := &MysqlConn{
4242
buf: newBuffer(nil),
4343
maxAllowedPacket: maxPacketSize,
4444
cfg: &Config{
@@ -65,7 +65,7 @@ func TestInterpolateParamsJSONRawMessage(t *testing.T) {
6565
}
6666

6767
func TestInterpolateParamsTooManyPlaceholders(t *testing.T) {
68-
mc := &mysqlConn{
68+
mc := &MysqlConn{
6969
buf: newBuffer(nil),
7070
maxAllowedPacket: maxPacketSize,
7171
cfg: &Config{
@@ -82,7 +82,7 @@ func TestInterpolateParamsTooManyPlaceholders(t *testing.T) {
8282
// We don't support placeholder in string literal for now.
8383
// https://github.com/go-sql-driver/mysql/pull/490
8484
func TestInterpolateParamsPlaceholderInString(t *testing.T) {
85-
mc := &mysqlConn{
85+
mc := &MysqlConn{
8686
buf: newBuffer(nil),
8787
maxAllowedPacket: maxPacketSize,
8888
cfg: &Config{
@@ -98,7 +98,7 @@ func TestInterpolateParamsPlaceholderInString(t *testing.T) {
9898
}
9999

100100
func TestInterpolateParamsUint64(t *testing.T) {
101-
mc := &mysqlConn{
101+
mc := &MysqlConn{
102102
buf: newBuffer(nil),
103103
maxAllowedPacket: maxPacketSize,
104104
cfg: &Config{
@@ -117,7 +117,7 @@ func TestInterpolateParamsUint64(t *testing.T) {
117117

118118
func TestCheckNamedValue(t *testing.T) {
119119
value := driver.NamedValue{Value: ^uint64(0)}
120-
x := &mysqlConn{}
120+
x := &MysqlConn{}
121121
err := x.CheckNamedValue(&value)
122122

123123
if err != nil {
@@ -132,7 +132,7 @@ func TestCheckNamedValue(t *testing.T) {
132132
// TestCleanCancel tests passed context is cancelled at start.
133133
// No packet should be sent. Connection should keep current status.
134134
func TestCleanCancel(t *testing.T) {
135-
mc := &mysqlConn{
135+
mc := &MysqlConn{
136136
closech: make(chan struct{}),
137137
}
138138
mc.startWatcher()
@@ -159,7 +159,7 @@ func TestCleanCancel(t *testing.T) {
159159

160160
func TestPingMarkBadConnection(t *testing.T) {
161161
nc := badConnection{err: errors.New("boom")}
162-
ms := &mysqlConn{
162+
ms := &MysqlConn{
163163
netConn: nc,
164164
buf: newBuffer(nc),
165165
maxAllowedPacket: defaultMaxAllowedPacket,
@@ -174,7 +174,7 @@ func TestPingMarkBadConnection(t *testing.T) {
174174

175175
func TestPingErrInvalidConn(t *testing.T) {
176176
nc := badConnection{err: errors.New("failed to write"), n: 10}
177-
ms := &mysqlConn{
177+
ms := &MysqlConn{
178178
netConn: nc,
179179
buf: newBuffer(nc),
180180
maxAllowedPacket: defaultMaxAllowedPacket,

connector.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
2424
var err error
2525

2626
// New mysqlConn
27-
mc := &mysqlConn{
27+
mc := &MysqlConn{
2828
maxAllowedPacket: maxPacketSize,
2929
maxWriteSize: maxPacketSize - 1,
3030
closech: make(chan struct{}),

0 commit comments

Comments
 (0)