diff --git a/cmd/sqlcmd/main.go b/cmd/sqlcmd/main.go index d39c839e..44684140 100644 --- a/cmd/sqlcmd/main.go +++ b/cmd/sqlcmd/main.go @@ -4,6 +4,7 @@ package main import ( + "errors" "fmt" "os" @@ -259,7 +260,7 @@ func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) { if args.ErrorsToStderr >= 0 { s.PrintError = func(msg string, severity uint8) bool { if severity >= stderrSeverity { - _, _ = os.Stderr.Write([]byte(msg + sqlcmd.SqlcmdEol)) + s.WriteError(os.Stderr, errors.New(msg+sqlcmd.SqlcmdEol)) return true } return false @@ -285,7 +286,7 @@ func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) { } else { for f := range args.InputFile { if err = s.IncludeFile(args.InputFile[f], true); err != nil { - _, _ = os.Stderr.Write([]byte(err.Error() + sqlcmd.SqlcmdEol)) + s.WriteError(s.GetError(), err) s.Exitcode = 1 break } diff --git a/pkg/sqlcmd/commands.go b/pkg/sqlcmd/commands.go index 9deb3d09..d3562e7f 100644 --- a/pkg/sqlcmd/commands.go +++ b/pkg/sqlcmd/commands.go @@ -116,12 +116,12 @@ func (c Commands) matchCommand(line string) (*Command, []string) { } func warnDisabled(s *Sqlcmd, args []string, line uint) error { - _, _ = s.GetError().Write([]byte(ErrCommandsDisabled.Error() + SqlcmdEol)) + s.WriteError(s.GetError(), ErrCommandsDisabled) return nil } func errorDisabled(s *Sqlcmd, args []string, line uint) error { - _, _ = s.GetError().Write([]byte(ErrCommandsDisabled.Error() + SqlcmdEol)) + s.WriteError(s.GetError(), ErrCommandsDisabled) s.Exitcode = 1 return ErrExitRequested } @@ -433,7 +433,7 @@ func resolveArgumentVariables(s *Sqlcmd, arg []rune, failOnUnresolved bool) (str if failOnUnresolved { return "", UndefinedVariable(varName) } - _, _ = s.GetError().Write([]byte(UndefinedVariable(varName).Error() + SqlcmdEol)) + s.WriteError(s.GetError(), UndefinedVariable(varName)) if b != nil { b.WriteString(string(arg[i : vl+1])) } diff --git a/pkg/sqlcmd/sqlcmd.go b/pkg/sqlcmd/sqlcmd.go index 1bd95d81..0f8be3a2 100644 --- a/pkg/sqlcmd/sqlcmd.go +++ b/pkg/sqlcmd/sqlcmd.go @@ -136,7 +136,7 @@ func (s *Sqlcmd) Run(once bool, processAll bool) error { args = make([]string, 0) once = true } else { - _, _ = s.GetOutput().Write([]byte(err.Error() + SqlcmdEol)) + s.WriteError(s.GetOutput(), err) } } if cmd != nil { @@ -146,7 +146,7 @@ func (s *Sqlcmd) Run(once bool, processAll bool) error { break } if err != nil { - _, _ = s.GetOutput().Write([]byte(err.Error() + SqlcmdEol)) + s.WriteError(s.GetOutput(), err) lastError = err } } @@ -209,6 +209,19 @@ func (s *Sqlcmd) SetError(e io.WriteCloser) { s.err = e } +// WriteError writes the error on specified stream +func (s *Sqlcmd) WriteError(stream io.Writer, err error) { + if strings.HasPrefix(err.Error(), ErrorPrefix) { + if s.GetError() != os.Stdout { + _, _ = s.GetError().Write([]byte(err.Error() + SqlcmdEol)) + } else { + _, _ = os.Stderr.Write([]byte(err.Error() + SqlcmdEol)) + } + } else { + _, _ = stream.Write([]byte(err.Error() + SqlcmdEol)) + } +} + // ConnectDb opens a connection to the database with the given modifications to the connection // nopw == true means don't prompt for a password if the auth type requires it // if connect is nil, ConnectDb uses the current connection. If non-nil and the connection succeeds, @@ -364,7 +377,7 @@ func setupCloseHandler(s *Sqlcmd) { signal.Notify(c, os.Interrupt, syscall.SIGTERM) go func() { <-c - _, _ = s.GetOutput().Write([]byte(ErrCtrlC.Error() + SqlcmdEol)) + s.WriteError(s.GetOutput(), ErrCtrlC) os.Exit(0) }() } diff --git a/pkg/sqlcmd/sqlcmd_test.go b/pkg/sqlcmd/sqlcmd_test.go index 5dbb26d5..0ae6605e 100644 --- a/pkg/sqlcmd/sqlcmd_test.go +++ b/pkg/sqlcmd/sqlcmd_test.go @@ -462,6 +462,44 @@ func TestQueryServerPropertyReturnsColumnName(t *testing.T) { } } +func TestSqlCmdOutputAndError(t *testing.T) { + s, outfile, errfile := setupSqlcmdWithFileErrorOutput(t) + defer os.Remove(outfile.Name()) + defer os.Remove(errfile.Name()) + s.Query = "select $(X" + err := s.Run(true, false) + if assert.NoError(t, err, "s.Run(once = true)") { + bytes, err := os.ReadFile(errfile.Name()) + if assert.NoError(t, err, "os.ReadFile") { + assert.Equal(t, "Sqlcmd: Error: Syntax error at line 1."+SqlcmdEol, string(bytes), "Expected syntax error not received for query execution") + } + } + s.Query = "select '1'" + err = s.Run(true, false) + if assert.NoError(t, err, "s.Run(once = true)") { + bytes, err := os.ReadFile(outfile.Name()) + if assert.NoError(t, err, "os.ReadFile") { + assert.Equal(t, "1"+SqlcmdEol+SqlcmdEol+"(1 row affected)"+SqlcmdEol, string(bytes), "Unexpected output for query execution") + } + } + + s, outfile, errfile = setupSqlcmdWithFileErrorOutput(t) + defer os.Remove(outfile.Name()) + defer os.Remove(errfile.Name()) + dataPath := "testdata" + string(os.PathSeparator) + err = s.IncludeFile(dataPath+"testerrorredirection.sql", false) + if assert.NoError(t, err, "IncludeFile testerrorredirection.sql false") { + bytes, err := os.ReadFile(outfile.Name()) + if assert.NoError(t, err, "os.ReadFile outfile") { + assert.Equal(t, "1"+SqlcmdEol+SqlcmdEol+"(1 row affected)"+SqlcmdEol, string(bytes), "Unexpected output for sql file execution in outfile") + } + bytes, err = os.ReadFile(errfile.Name()) + if assert.NoError(t, err, "os.ReadFile errfile") { + assert.Equal(t, "Sqlcmd: Error: Syntax error at line 3."+SqlcmdEol, string(bytes), "Expected syntax error not found in errfile") + } + } +} + // runSqlCmd uses lines as input for sqlcmd instead of relying on file or console input func runSqlCmd(t testing.TB, s *Sqlcmd, lines []string) error { t.Helper() @@ -509,6 +547,28 @@ func setupSqlcmdWithFileOutput(t testing.TB) (*Sqlcmd, *os.File) { return s, file } +func setupSqlcmdWithFileErrorOutput(t testing.TB) (*Sqlcmd, *os.File, *os.File) { + t.Helper() + v := InitializeVariables(true) + v.Set(SQLCMDMAXVARTYPEWIDTH, "0") + s := New(nil, "", v) + s.Connect = newConnect(t) + s.Format = NewSQLCmdDefaultFormatter(true) + outfile, err := os.CreateTemp("", "sqlcmdout") + assert.NoError(t, err, "os.CreateTemp") + errfile, err := os.CreateTemp("", "sqlcmderr") + assert.NoError(t, err, "os.CreateTemp") + s.SetOutput(outfile) + s.SetError(errfile) + err = s.ConnectDb(nil, true) + if err != nil { + os.Remove(outfile.Name()) + os.Remove(errfile.Name()) + } + assert.NoError(t, err, "s.ConnectDB") + return s, outfile, errfile +} + // Assuming public Azure, use AAD when SQLCMDUSER environment variable is not set func canTestAzureAuth() bool { server := os.Getenv(SQLCMDSERVER) diff --git a/pkg/sqlcmd/testdata/testerrorredirection.sql b/pkg/sqlcmd/testdata/testerrorredirection.sql new file mode 100644 index 00000000..ef8ff9f2 --- /dev/null +++ b/pkg/sqlcmd/testdata/testerrorredirection.sql @@ -0,0 +1,4 @@ +select '1' +go +select $(var +go \ No newline at end of file