Skip to content

Commit 257886e

Browse files
authored
fix: support list of cors origins (#1262)
1 parent ff7c877 commit 257886e

File tree

5 files changed

+125
-40
lines changed

5 files changed

+125
-40
lines changed

cmd/vela-server/main.go

+5
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,11 @@ func main() {
7575
Usage: "web ui oauth callback path",
7676
Value: "/account/authenticate",
7777
},
78+
&cli.StringSliceFlag{
79+
EnvVars: []string{"VELA_CORS_ALLOW_ORIGINS", "VELA_CORS_ALLOWED_ORIGINS"},
80+
Name: "cors-allow-origins",
81+
Usage: "list of origins a cross-domain request can be executed from",
82+
},
7883
&cli.StringFlag{
7984
EnvVars: []string{"VELA_SECRET"},
8085
Name: "vela-secret",

cmd/vela-server/metadata.go

+4
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,10 @@ func metadataVela(c *cli.Context) (*internal.Vela, error) {
109109
vela.WebAddress = c.String("webui-addr")
110110
}
111111

112+
if len(c.StringSlice("cors-allow-origins")) > 0 {
113+
vela.CorsAllowOrigins = c.StringSlice("cors-allow-origins")
114+
}
115+
112116
if len(c.String("webui-oauth-callback")) > 0 {
113117
vela.WebOauthCallbackPath = c.String("webui-oauth-callback")
114118
}

internal/metadata.go

+1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ type (
3232
AccessTokenDuration time.Duration `json:"access_token_duration"`
3333
RefreshTokenDuration time.Duration `json:"refresh_token_duration"`
3434
OpenIDIssuer string `json:"oidc_issuer"`
35+
CorsAllowOrigins []string `json:"cors_allow_origins"`
3536
}
3637

3738
// Metadata is the extra set of data passed to the compiler for

router/middleware/header.go

+22-4
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,9 @@ func Options(c *gin.Context) {
3232
} else {
3333
c.Header("Access-Control-Allow-Origin", "*")
3434

35-
if len(m.Vela.WebAddress) > 0 {
36-
c.Header("Access-Control-Allow-Origin", m.Vela.WebAddress)
35+
origin := CorsAllowOrigin(c, m)
36+
if len(origin) > 0 {
37+
c.Header("Access-Control-Allow-Origin", origin)
3738
c.Header("Access-Control-Allow-Credentials", "true")
3839
}
3940

@@ -65,14 +66,31 @@ func Cors(c *gin.Context) {
6566

6667
c.Header("Access-Control-Allow-Origin", "*")
6768

68-
if len(m.Vela.WebAddress) > 0 {
69-
c.Header("Access-Control-Allow-Origin", m.Vela.WebAddress)
69+
origin := CorsAllowOrigin(c, m)
70+
if len(origin) > 0 {
71+
c.Header("Access-Control-Allow-Origin", origin)
7072
c.Header("Access-Control-Allow-Credentials", "true")
7173
}
7274

7375
c.Header("Access-Control-Expose-Headers", "link, x-total-count")
7476
}
7577

78+
// CorsAllowOrigin is a helper function that returns the
79+
// allowed origin for CORS requests by checking the
80+
// request origin against the allowed origins in the
81+
// Vela metadata.
82+
func CorsAllowOrigin(c *gin.Context, m *internal.Metadata) string {
83+
origin := c.Request.Header.Get("Origin")
84+
for _, domain := range m.Vela.CorsAllowOrigins {
85+
if domain == origin {
86+
return domain
87+
}
88+
}
89+
90+
// return the Vela web address as the default to preserve functionality
91+
return m.Vela.WebAddress
92+
}
93+
7694
// RequestVersion is a middleware function that injects the Vela API version
7795
// information into the request so it will be logged. This is
7896
// intended for debugging and troubleshooting.

router/middleware/header_test.go

+93-36
Original file line numberDiff line numberDiff line change
@@ -176,45 +176,102 @@ func TestMiddleware_Options_InvalidMethod(t *testing.T) {
176176
}
177177

178178
func TestMiddleware_Cors(t *testing.T) {
179-
// setup types
180-
wantOrigin := "*"
181-
wantExposeHeaders := "link, x-total-count"
182-
m := &internal.Metadata{
183-
Vela: &internal.Vela{
184-
Address: "http://localhost:8080",
179+
tests := []struct {
180+
name string
181+
m *internal.Metadata
182+
origin string
183+
expectedOrigin string
184+
expectedCredentials string
185+
expectedExposeHeaders string
186+
}{
187+
{
188+
name: "*",
189+
m: &internal.Metadata{
190+
Vela: &internal.Vela{
191+
Address: "http://localhost:8080",
192+
CorsAllowOrigins: []string{},
193+
},
194+
},
195+
origin: "http://localhost:8888",
196+
expectedOrigin: "*",
197+
expectedCredentials: "",
198+
expectedExposeHeaders: "link, x-total-count",
199+
},
200+
{
201+
name: "WebAddress is origin",
202+
m: &internal.Metadata{
203+
Vela: &internal.Vela{
204+
WebAddress: "http://localhost:8888",
205+
CorsAllowOrigins: []string{},
206+
},
207+
},
208+
origin: "http://localhost:8888",
209+
expectedOrigin: "http://localhost:8888",
210+
expectedCredentials: "true",
211+
expectedExposeHeaders: "link, x-total-count",
212+
},
213+
{
214+
name: "CORSAllowOrigins origin is web address",
215+
m: &internal.Metadata{
216+
Vela: &internal.Vela{
217+
WebAddress: "http://localhost:8888",
218+
CorsAllowOrigins: []string{"http://localhost:3000", "http://localhost:3001"},
219+
},
220+
},
221+
origin: "http://localhost:8888",
222+
expectedOrigin: "http://localhost:8888",
223+
expectedCredentials: "true",
224+
expectedExposeHeaders: "link, x-total-count",
225+
},
226+
{
227+
name: "CORSAllowOrigins origin is in list",
228+
m: &internal.Metadata{
229+
Vela: &internal.Vela{
230+
WebAddress: "",
231+
CorsAllowOrigins: []string{"http://localhost:3000", "http://localhost:3001", "http://localhost:8888"},
232+
},
233+
},
234+
origin: "http://localhost:8888",
235+
expectedOrigin: "http://localhost:8888",
236+
expectedCredentials: "true",
237+
expectedExposeHeaders: "link, x-total-count",
185238
},
186239
}
187240

188-
// setup context
189-
gin.SetMode(gin.TestMode)
190-
191-
resp := httptest.NewRecorder()
192-
context, engine := gin.CreateTestContext(resp)
193-
context.Request, _ = http.NewRequest(http.MethodGet, "/health", nil)
194-
195-
// setup mock server
196-
engine.Use(Metadata(m))
197-
engine.Use(Cors)
198-
engine.GET("/health", func(c *gin.Context) {
199-
c.Status(http.StatusOK)
200-
})
201-
202-
// run test
203-
engine.ServeHTTP(context.Writer, context.Request)
204-
205-
gotOrigin := context.Writer.Header().Get("Access-Control-Allow-Origin")
206-
gotExposeHeaders := context.Writer.Header().Get("Access-Control-Expose-Headers")
207-
208-
if resp.Code != http.StatusOK {
209-
t.Errorf("CORS returned %v, want %v", resp.Code, http.StatusOK)
210-
}
211-
212-
if !reflect.DeepEqual(gotOrigin, wantOrigin) {
213-
t.Errorf("CORS Access-Control-Allow-Origin is %v, want %v", gotOrigin, wantOrigin)
214-
}
215-
216-
if !reflect.DeepEqual(gotExposeHeaders, wantExposeHeaders) {
217-
t.Errorf("CORS Access-Control-Expose-Headers is %v, want %v", gotExposeHeaders, wantExposeHeaders)
241+
for _, tt := range tests {
242+
t.Run(tt.name, func(t *testing.T) {
243+
gin.SetMode(gin.TestMode)
244+
resp := httptest.NewRecorder()
245+
context, engine := gin.CreateTestContext(resp)
246+
context.Request, _ = http.NewRequest(http.MethodGet, "/health", nil)
247+
context.Request.Header.Add("Origin", tt.origin)
248+
249+
// inject metadata
250+
engine.Use(func(c *gin.Context) {
251+
c.Set("metadata", tt.m)
252+
c.Next()
253+
})
254+
engine.Use(Cors)
255+
engine.GET("/health", func(c *gin.Context) {
256+
c.Status(http.StatusOK)
257+
})
258+
engine.ServeHTTP(context.Writer, context.Request)
259+
260+
gotOrigin := context.Writer.Header().Get("Access-Control-Allow-Origin")
261+
if gotOrigin != tt.expectedOrigin {
262+
t.Errorf("Access-Control-Allow-Origin is %v; want %v", gotOrigin, tt.expectedOrigin)
263+
}
264+
265+
gotCredentials := context.Writer.Header().Get("Access-Control-Allow-Credentials")
266+
if gotCredentials != tt.expectedCredentials {
267+
t.Errorf("Access-Control-Allow-Credentials is %v; want %v", gotCredentials, tt.expectedCredentials)
268+
}
269+
270+
gotExposeHeaders := context.Writer.Header().Get("Access-Control-Expose-Headers")
271+
if gotExposeHeaders != tt.expectedExposeHeaders {
272+
t.Errorf("Access-Control-Expose-Headers is %v; want %v", gotExposeHeaders, tt.expectedExposeHeaders)
273+
}
274+
})
218275
}
219276
}
220277

0 commit comments

Comments
 (0)