diff --git a/src/main.go b/src/main.go index 9ddaf45..adfe4d0 100644 --- a/src/main.go +++ b/src/main.go @@ -302,7 +302,7 @@ func (t *tester) postProcess() { t.mdb.Close() } -func (t *tester) addFailure(testSuite *XUnitTestSuite, err *error, cnt int) { +func (t *tester) addFailure(err *error, cnt int) { testSuite.TestCases = append(testSuite.TestCases, XUnitTestCase{ Classname: "", Name: t.testFileName(), @@ -313,7 +313,7 @@ func (t *tester) addFailure(testSuite *XUnitTestSuite, err *error, cnt int) { testSuite.Failures++ } -func (t *tester) addSuccess(testSuite *XUnitTestSuite, startTime *time.Time, cnt int) { +func (t *tester) addSuccess(startTime *time.Time, cnt int) { testSuite.TestCases = append(testSuite.TestCases, XUnitTestCase{ Classname: "", Name: t.testFileName(), @@ -328,13 +328,13 @@ func (t *tester) Run() error { queries, err := t.loadQueries() if err != nil { err = errors.Trace(err) - t.addFailure(&testSuite, &err, 0) + t.addFailure(&err, 0) return err } if err = t.openResult(); err != nil { err = errors.Trace(err) - t.addFailure(&testSuite, &err, 0) + t.addFailure(&err, 0) return err } @@ -378,7 +378,7 @@ func (t *tester) Run() error { concurrentSize, err = strconv.Atoi(strings.TrimSpace(s)) if err != nil { err = errors.Annotate(err, "Atoi failed") - t.addFailure(&testSuite, &err, testCnt) + t.addFailure(&err, testCnt) return err } } @@ -386,7 +386,7 @@ func (t *tester) Run() error { t.enableConcurrent = false if err = t.concurrentRun(concurrentQueue, concurrentSize); err != nil { err = errors.Annotate(err, fmt.Sprintf("concurrent test failed in %v", t.name)) - t.addFailure(&testSuite, &err, testCnt) + t.addFailure(&err, testCnt) return err } t.expectedErrs = nil @@ -405,7 +405,7 @@ func (t *tester) Run() error { concurrentQueue = append(concurrentQueue, q) } else if err = t.execute(q); err != nil { err = errors.Annotate(err, fmt.Sprintf("sql:%v", q.Query)) - t.addFailure(&testSuite, &err, testCnt) + t.addFailure(&err, testCnt) return err } @@ -425,7 +425,7 @@ func (t *tester) Run() error { colNr, err := strconv.Atoi(cols[i]) if err != nil { err = errors.Annotate(err, fmt.Sprintf("Could not parse column in --replace_column: sql:%v", q.Query)) - t.addFailure(&testSuite, &err, testCnt) + t.addFailure(&err, testCnt) return err } @@ -501,7 +501,7 @@ func (t *tester) Run() error { fmt.Printf("%s: ok! %d test cases passed, take time %v s\n", t.testFileName(), testCnt, time.Since(startTime).Seconds()) if xmlPath != "" { - t.addSuccess(&testSuite, &startTime, testCnt) + t.addSuccess(&startTime, testCnt) } return t.flushResult() @@ -1141,21 +1141,7 @@ func consumeError() []error { } } -func main() { - flag.Parse() - tests := flag.Args() - startTime := time.Now() - if ll := os.Getenv("LOG_LEVEL"); ll != "" { - logLevel = ll - } - if logLevel != "" { - ll, err := log.ParseLevel(logLevel) - if err != nil { - log.Errorf("error parsing log level %s: %v", logLevel, err) - } - log.SetLevel(ll) - } - +func writeXUnitFile(startTime time.Time) { if xmlPath != "" { _, err := os.Stat(xmlPath) if err == nil { @@ -1176,29 +1162,36 @@ func main() { log.Error("open xunit file fail:", err) os.Exit(1) } - - testSuite = XUnitTestSuite{ - Name: "", - Tests: 0, - Failures: 0, - Properties: make([]XUnitProperty, 0), - TestCases: make([]XUnitTestCase, 0), + testSuite.Tests = len(testSuite.TestCases) + testSuite.Time = fmt.Sprintf("%fs", time.Since(startTime).Seconds()) + testSuite.Properties = append(testSuite.Properties, XUnitProperty{ + Name: "go.version", + Value: goVersion(), + }) + err = Write(xmlFile, testSuite) + if err != nil { + log.Error("Write xunit file fail:", err) } + err = xmlFile.Close() + if err != nil { + log.Error("Close xunit file fail:", err) + } + } +} - defer func() { - if xmlFile != nil { - testSuite.Tests = len(tests) - testSuite.Time = fmt.Sprintf("%fs", time.Since(startTime).Seconds()) - testSuite.Properties = append(testSuite.Properties, XUnitProperty{ - Name: "go.version", - Value: goVersion(), - }) - err := Write(xmlFile, testSuite) - if err != nil { - log.Error("Write xunit file fail:", err) - } - } - }() +func main() { + flag.Parse() + tests := flag.Args() + startTime := time.Now() + if ll := os.Getenv("LOG_LEVEL"); ll != "" { + logLevel = ll + } + if logLevel != "" { + ll, err := log.ParseLevel(logLevel) + if err != nil { + log.Errorf("error parsing log level %s: %v", logLevel, err) + } + log.SetLevel(ll) } // we will run all tests if no tests assigned @@ -1209,6 +1202,15 @@ func main() { } } + if xmlPath != "" { + testSuite = XUnitTestSuite{ + Name: "", + Tests: 0, + Failures: 0, + Properties: make([]XUnitProperty, 0), + TestCases: make([]XUnitTestCase, 0), + } + } if !record { log.Infof("running tests: %v", tests) } else { @@ -1225,6 +1227,7 @@ func main() { es := consumeError() println() + writeXUnitFile(startTime) if len(es) != 0 { log.Errorf("%d tests failed\n", len(es)) for _, item := range es {