diff --git a/sqlstore.go b/sqlstore.go index b16b0a74d..5a4a5d690 100644 --- a/sqlstore.go +++ b/sqlstore.go @@ -3,6 +3,7 @@ package quickfix import ( "database/sql" "fmt" + "regexp" "time" "github.com/quickfixgo/quickfix/config" @@ -19,6 +20,27 @@ type sqlStore struct { sqlDataSourceName string sqlConnMaxLifetime time.Duration db *sql.DB + placeholder placeholderFunc +} + +type placeholderFunc func(int) string + +var rePlaceholder = regexp.MustCompile(`\?`) + +func sqlString(raw string, placeholder placeholderFunc) string { + if placeholder == nil { + return raw + } + idx := 0 + return rePlaceholder.ReplaceAllStringFunc(raw, func(s string) string { + new := placeholder(idx) + idx += 1 + return new + }) +} + +func postgresPlaceholder(i int) string { + return fmt.Sprintf("$%d", i+1) } // NewSQLStoreFactory returns a sql-based implementation of MessageStoreFactory @@ -60,6 +82,10 @@ func newSQLStore(sessionID SessionID, driver string, dataSourceName string, conn } store.cache.Reset() + if store.sqlDriver == "postgres" { + store.placeholder = postgresPlaceholder + } + if store.db, err = sql.Open(store.sqlDriver, store.sqlDataSourceName); err != nil { return nil, err } @@ -78,10 +104,10 @@ func newSQLStore(sessionID SessionID, driver string, dataSourceName string, conn // Reset deletes the store records and sets the seqnums back to 1 func (store *sqlStore) Reset() error { s := store.sessionID - _, err := store.db.Exec(`DELETE FROM messages + _, err := store.db.Exec(sqlString(`DELETE FROM messages WHERE beginstring=? AND session_qualifier=? AND sendercompid=? AND sendersubid=? AND senderlocid=? - AND targetcompid=? AND targetsubid=? AND targetlocid=?`, + AND targetcompid=? AND targetsubid=? AND targetlocid=?`, store.placeholder), s.BeginString, s.Qualifier, s.SenderCompID, s.SenderSubID, s.SenderLocationID, s.TargetCompID, s.TargetSubID, s.TargetLocationID) @@ -93,11 +119,11 @@ func (store *sqlStore) Reset() error { return err } - _, err = store.db.Exec(`UPDATE sessions + _, err = store.db.Exec(sqlString(`UPDATE sessions SET creation_time=?, incoming_seqnum=?, outgoing_seqnum=? WHERE beginstring=? AND session_qualifier=? AND sendercompid=? AND sendersubid=? AND senderlocid=? - AND targetcompid=? AND targetsubid=? AND targetlocid=?`, + AND targetcompid=? AND targetsubid=? AND targetlocid=?`, store.placeholder), store.cache.CreationTime(), store.cache.NextTargetMsgSeqNum(), store.cache.NextSenderMsgSeqNum(), s.BeginString, s.Qualifier, s.SenderCompID, s.SenderSubID, s.SenderLocationID, @@ -118,11 +144,11 @@ func (store *sqlStore) populateCache() (err error) { s := store.sessionID var creationTime time.Time var incomingSeqNum, outgoingSeqNum int - row := store.db.QueryRow(`SELECT creation_time, incoming_seqnum, outgoing_seqnum + row := store.db.QueryRow(sqlString(`SELECT creation_time, incoming_seqnum, outgoing_seqnum FROM sessions WHERE beginstring=? AND session_qualifier=? AND sendercompid=? AND sendersubid=? AND senderlocid=? - AND targetcompid=? AND targetsubid=? AND targetlocid=?`, + AND targetcompid=? AND targetsubid=? AND targetlocid=?`, store.placeholder), s.BeginString, s.Qualifier, s.SenderCompID, s.SenderSubID, s.SenderLocationID, s.TargetCompID, s.TargetSubID, s.TargetLocationID) @@ -143,12 +169,12 @@ func (store *sqlStore) populateCache() (err error) { } // session record not found, create it - _, err = store.db.Exec(`INSERT INTO sessions ( + _, err = store.db.Exec(sqlString(`INSERT INTO sessions ( creation_time, incoming_seqnum, outgoing_seqnum, beginstring, session_qualifier, sendercompid, sendersubid, senderlocid, targetcompid, targetsubid, targetlocid) - VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, store.placeholder), store.cache.creationTime, store.cache.NextTargetMsgSeqNum(), store.cache.NextSenderMsgSeqNum(), @@ -172,10 +198,10 @@ func (store *sqlStore) NextTargetMsgSeqNum() int { // SetNextSenderMsgSeqNum sets the next MsgSeqNum that will be sent func (store *sqlStore) SetNextSenderMsgSeqNum(next int) error { s := store.sessionID - _, err := store.db.Exec(`UPDATE sessions SET outgoing_seqnum = ? + _, err := store.db.Exec(sqlString(`UPDATE sessions SET outgoing_seqnum = ? WHERE beginstring=? AND session_qualifier=? AND sendercompid=? AND sendersubid=? AND senderlocid=? - AND targetcompid=? AND targetsubid=? AND targetlocid=?`, + AND targetcompid=? AND targetsubid=? AND targetlocid=?`, store.placeholder), next, s.BeginString, s.Qualifier, s.SenderCompID, s.SenderSubID, s.SenderLocationID, s.TargetCompID, s.TargetSubID, s.TargetLocationID) @@ -188,10 +214,10 @@ func (store *sqlStore) SetNextSenderMsgSeqNum(next int) error { // SetNextTargetMsgSeqNum sets the next MsgSeqNum that should be received func (store *sqlStore) SetNextTargetMsgSeqNum(next int) error { s := store.sessionID - _, err := store.db.Exec(`UPDATE sessions SET incoming_seqnum = ? + _, err := store.db.Exec(sqlString(`UPDATE sessions SET incoming_seqnum = ? WHERE beginstring=? AND session_qualifier=? AND sendercompid=? AND sendersubid=? AND senderlocid=? - AND targetcompid=? AND targetsubid=? AND targetlocid=?`, + AND targetcompid=? AND targetsubid=? AND targetlocid=?`, store.placeholder), next, s.BeginString, s.Qualifier, s.SenderCompID, s.SenderSubID, s.SenderLocationID, s.TargetCompID, s.TargetSubID, s.TargetLocationID) @@ -221,12 +247,12 @@ func (store *sqlStore) CreationTime() time.Time { func (store *sqlStore) SaveMessage(seqNum int, msg []byte) error { s := store.sessionID - _, err := store.db.Exec(`INSERT INTO messages ( + _, err := store.db.Exec(sqlString(`INSERT INTO messages ( msgseqnum, message, beginstring, session_qualifier, sendercompid, sendersubid, senderlocid, targetcompid, targetsubid, targetlocid) - VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, store.placeholder), seqNum, string(msg), s.BeginString, s.Qualifier, s.SenderCompID, s.SenderSubID, s.SenderLocationID, @@ -238,12 +264,12 @@ func (store *sqlStore) SaveMessage(seqNum int, msg []byte) error { func (store *sqlStore) GetMessages(beginSeqNum, endSeqNum int) ([][]byte, error) { s := store.sessionID var msgs [][]byte - rows, err := store.db.Query(`SELECT message FROM messages + rows, err := store.db.Query(sqlString(`SELECT message FROM messages WHERE beginstring=? AND session_qualifier=? AND sendercompid=? AND sendersubid=? AND senderlocid=? AND targetcompid=? AND targetsubid=? AND targetlocid=? AND msgseqnum>=? AND msgseqnum<=? - ORDER BY msgseqnum`, + ORDER BY msgseqnum`, store.placeholder), s.BeginString, s.Qualifier, s.SenderCompID, s.SenderSubID, s.SenderLocationID, s.TargetCompID, s.TargetSubID, s.TargetLocationID, diff --git a/sqlstore_test.go b/sqlstore_test.go index b252b2c25..da2c8a5a4 100644 --- a/sqlstore_test.go +++ b/sqlstore_test.go @@ -60,6 +60,11 @@ TargetCompID=%s`, sqlDriver, sqlDsn, sessionID.BeginString, sessionID.SenderComp require.Nil(suite.T(), err) } +func (suite *SQLStoreTestSuite) TestSqlPlaceholderReplacement() { + got := sqlString("A ? B ? C ?", postgresPlaceholder) + suite.Equal("A $1 B $2 C $3", got) +} + func (suite *SQLStoreTestSuite) TearDownTest() { suite.msgStore.Close() os.RemoveAll(suite.sqlStoreRootPath)