From e0fba8ac793e73e7e6ac33fcc75acec7355abea1 Mon Sep 17 00:00:00 2001 From: Wendell Sun Date: Mon, 20 Jun 2022 15:09:25 +0800 Subject: [PATCH] Handle access/refresh token without expiration time --- token_store.go | 21 ++++++++++++++++----- token_store_test.go | 28 ++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 5 deletions(-) diff --git a/token_store.go b/token_store.go index 7ba6eed..a195ad1 100644 --- a/token_store.go +++ b/token_store.go @@ -97,13 +97,14 @@ func (s *TokenStore) gc() { for range s.ticker.C { now := time.Now().Unix() var count int64 - if err := s.db.Table(s.tableName).Where("expired_at <= ?", now).Or("code = ? and access = ? AND refresh = ?", "", "", "").Count(&count).Error; err != nil { + db := s.db.Table(s.tableName).Where("expired_at != 0 AND expired_at <= ?", now).Or("code = ? and access = ? AND refresh = ?", "", "", "") + if err := db.Count(&count).Error; err != nil { s.errorf("[ERROR]:%s\n", err) return } if count > 0 { // not soft delete. - if err := s.db.Table(s.tableName).Where("expired_at <= ?", now).Or("code = ? and access = ? AND refresh = ?", "", "", "").Unscoped().Delete(&TokenStoreItem{}).Error; err != nil { + if err := db.Unscoped().Delete(&TokenStoreItem{}).Error; err != nil { s.errorf("[ERROR]:%s\n", err) } } @@ -125,11 +126,21 @@ func (s *TokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error { item.ExpiredAt = info.GetCodeCreateAt().Add(info.GetCodeExpiresIn()).Unix() } else { item.Access = info.GetAccess() - item.ExpiredAt = info.GetAccessCreateAt().Add(info.GetAccessExpiresIn()).Unix() + if accessExpiresIn := info.GetAccessExpiresIn(); accessExpiresIn != 0 { + item.ExpiredAt = info.GetAccessCreateAt().Add(accessExpiresIn).Unix() + } if refresh := info.GetRefresh(); refresh != "" { - item.Refresh = info.GetRefresh() - item.ExpiredAt = info.GetRefreshCreateAt().Add(info.GetRefreshExpiresIn()).Unix() + item.Refresh = refresh + refreshExpiresIn := info.GetRefreshExpiresIn() + refreshExpiredAt := info.GetRefreshCreateAt().Add(refreshExpiresIn).Unix() + if item.ExpiredAt != 0 { + if refreshExpiresIn == 0 { + item.ExpiredAt = 0 + } else if refreshExpiredAt > item.ExpiredAt { + item.ExpiredAt = refreshExpiredAt + } + } } } diff --git a/token_store_test.go b/token_store_test.go index 34c6fbf..fad8ad0 100644 --- a/token_store_test.go +++ b/token_store_test.go @@ -72,6 +72,34 @@ func TestTokenStore(t *testing.T) { So(ainfo, ShouldBeNil) }) + Convey("Test access token(no expiration time) store", func() { + info := &models.Token{ + ClientID: "1", + UserID: "1_1", + RedirectURI: "http://localhost/", + Scope: "all", + Access: "1_1_2", + AccessCreateAt: time.Now(), + AccessExpiresIn: 0, + } + err := store.Create(context.Background(), info) + So(err, ShouldBeNil) + + // wait gc + time.Sleep(time.Second) + + ainfo, err := store.GetByAccess(context.Background(), info.GetAccess()) + So(err, ShouldBeNil) + So(ainfo.GetUserID(), ShouldEqual, info.GetUserID()) + + err = store.RemoveByAccess(context.Background(), info.GetAccess()) + So(err, ShouldBeNil) + + ainfo, err = store.GetByAccess(context.Background(), info.GetAccess()) + So(err, ShouldBeNil) + So(ainfo, ShouldBeNil) + }) + Convey("Test refresh token store", func() { info := &models.Token{ ClientID: "1",