Skip to content

Commit c3aef7d

Browse files
namoscatonrutman
andauthored
[sc-23218] Generalize middleware (#5)
Co-authored-by: Nate Rutman <nrutman@users.noreply.github.com>
1 parent c5b1ca3 commit c3aef7d

File tree

7 files changed

+121
-134
lines changed

7 files changed

+121
-134
lines changed

README.md

+3-7
Original file line numberDiff line numberDiff line change
@@ -86,14 +86,9 @@ The **Results Validator** ensures that the pipeline has fulfilled the interface
8686

8787
## Middleware
8888

89-
If **Middleware** is specified, it will be run on the specified stage lifecycle event(s) for each stage in the pipeline.
89+
If **Middleware** is specified, it will be wrapped around each stage in the pipeline. This follows [a pattern similar to Express](https://expressjs.com/en/guide/using-middleware.html). Each middleware is called in the order it is specified and includes a `next()` to call the next middleware/stage.
9090

91-
| Stage Event | Description |
92-
| ----------------- | ---------------------------------- |
93-
| `onStageStart` | Runs before each stage is executed |
94-
| `onStageComplete` | Runs after each stage is executed |
95-
96-
Middleware is specified as an object with middleware callbacks mapped to at least one of the above event keys. A middleware callback is provided the following attributes:
91+
A middleware callback is provided the following attributes:
9792

9893
| Parameter | Description |
9994
| -------------- | ----------------------------------------------------------------------------- |
@@ -102,6 +97,7 @@ Middleware is specified as an object with middleware callbacks mapped to at leas
10297
| `results` | A read-only set of results returned by stages so far |
10398
| `stageNames` | An array of the names of the methods that make up the current pipeline stages |
10499
| `currentStage` | The name of the current pipeline stage |
100+
| `next` | Calls the next middleware in the stack (or the stage if none) |
105101

106102
See the [LogStageMiddlewareFactory](./src/middleware/logStageMiddlewareFactory.ts) for a simple middleware implementation. It is wrapped in a factory method so a log method can be properly injected.
107103

package.json

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"main": "build/index.js",
66
"types": "build/index.d.ts",
77
"scripts": {
8+
"dev": "tsc --noEmit --watch",
89
"prepack": "npm run build",
910
"build": "tsc --project tsconfig.build.json",
1011
"eslint": "eslint --ext .js,.ts --cache --cache-location=node_modules/.cache/eslint --cache-strategy content .",

src/__mocks__/TestPipeline.ts

+9-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import { last } from "lodash";
2-
import {
2+
import type {
33
PipelineInitializer,
4+
PipelineMiddleware,
45
PipelineResultValidator,
56
PipelineStage,
67
} from "../types";
@@ -24,6 +25,12 @@ export type TestStage = PipelineStage<
2425
TestPipelineResults
2526
>;
2627

28+
export type TestMiddleware = PipelineMiddleware<
29+
TestPipelineArguments,
30+
TestPipelineContext,
31+
TestPipelineResults
32+
>;
33+
2734
/**
2835
* A stage to set up the test pipeline
2936
*/
@@ -82,7 +89,7 @@ export const errorStage: TestStage = () => {
8289
*/
8390
export const testPipelineResultValidator: PipelineResultValidator<
8491
TestPipelineResults
85-
> = (results) => {
92+
> = (results): results is TestPipelineResults => {
8693
// false if sum is not a number
8794
if (typeof results.sum !== "number") {
8895
return false;

src/__tests__/buildPipeline.test.ts

+45-26
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import { logStageMiddlewareFactory } from "middleware/logStageMiddlewareFactory";
12
import {
3+
TestMiddleware,
24
TestPipelineArguments,
35
TestPipelineContext,
46
TestPipelineResults,
@@ -12,7 +14,6 @@ import {
1214
import { buildPipeline } from "../buildPipeline";
1315
import { PipelineError } from "../error/PipelineError";
1416
import { returnSumResult } from "./../__mocks__/TestPipeline";
15-
import { PipelineMiddleware } from "./../types";
1617

1718
const INCREMENT = 5;
1819

@@ -53,42 +54,55 @@ describe("buildPipeline", () => {
5354
});
5455

5556
describe("when using middleware", () => {
56-
const testStart = jest.fn();
57-
const testComplete = jest.fn();
58-
const testMiddleware: PipelineMiddleware = {
59-
onStageStart: testStart,
60-
onStageComplete: testComplete,
61-
};
62-
63-
const partialComplete = jest.fn();
64-
const partialMiddleware: PipelineMiddleware = {
65-
onStageComplete: partialComplete,
66-
};
67-
68-
beforeEach(() => {
69-
testStart.mockClear();
70-
testComplete.mockClear();
71-
partialComplete.mockClear();
72-
});
57+
let middlewareCalls: string[];
58+
59+
let testMiddleware1: TestMiddlewareMock;
60+
let testMiddleware2: TestMiddlewareMock;
61+
62+
beforeAll(async () => {
63+
middlewareCalls = [];
64+
65+
const createMiddlewareMock = (name: string): TestMiddlewareMock => {
66+
return jest.fn(({ currentStage, next }) => {
67+
middlewareCalls.push(`${currentStage}: ${name}`);
7368

74-
it("should run the test middleware", async () => {
75-
await runPipelineForStages(successfulStages, [testMiddleware]);
69+
return next();
70+
});
71+
};
7672

77-
expect(testStart).toHaveBeenCalledTimes(successfulStages.length);
78-
expect(testComplete).toHaveBeenCalledTimes(successfulStages.length);
73+
testMiddleware1 = createMiddlewareMock("testMiddleware1");
74+
testMiddleware2 = createMiddlewareMock("testMiddleware2");
75+
76+
await runPipelineForStages(successfulStages, [
77+
logStageMiddlewareFactory(),
78+
testMiddleware1,
79+
testMiddleware2,
80+
]);
7981
});
8082

81-
it("should run the partial middleware", async () => {
82-
await runPipelineForStages(successfulStages, [partialMiddleware]);
83+
it(`should run each middleware ${successfulStages.length} times`, () => {
84+
expect(testMiddleware1).toHaveBeenCalledTimes(successfulStages.length);
85+
expect(testMiddleware2).toHaveBeenCalledTimes(successfulStages.length);
86+
});
8387

84-
expect(partialComplete).toHaveBeenCalledTimes(successfulStages.length);
88+
it("should run middleware in the correct order", () => {
89+
expect(middlewareCalls).toEqual([
90+
"additionStage: testMiddleware1",
91+
"additionStage: testMiddleware2",
92+
"additionStage: testMiddleware1",
93+
"additionStage: testMiddleware2",
94+
"returnSumResult: testMiddleware1",
95+
"returnSumResult: testMiddleware2",
96+
"returnHistoryResult: testMiddleware1",
97+
"returnHistoryResult: testMiddleware2",
98+
]);
8599
});
86100
});
87101
});
88102

89103
function runPipelineForStages(
90104
stages: TestStage[],
91-
middleware: PipelineMiddleware[] = [],
105+
middleware: TestMiddleware[] = [],
92106
) {
93107
const pipeline = buildPipeline<
94108
TestPipelineArguments,
@@ -104,3 +118,8 @@ function runPipelineForStages(
104118

105119
return pipeline({ increment: INCREMENT });
106120
}
121+
122+
type TestMiddlewareMock = jest.Mock<
123+
ReturnType<TestMiddleware>,
124+
Parameters<TestMiddleware>
125+
>;

src/buildPipeline.ts

+33-59
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,10 @@
1-
import { compact, merge } from "lodash";
1+
import { merge } from "lodash";
22
import { PipelineError } from "./error/PipelineError";
3-
import {
3+
import type {
44
Pipeline,
55
PipelineInitializer,
66
PipelineMetadata,
77
PipelineMiddleware,
8-
PipelineMiddlewareCallable,
9-
PipelineMiddlewareEventType,
10-
PipelineMiddlewarePayload,
118
PipelineResultValidator,
129
PipelineStage,
1310
} from "./types";
@@ -21,7 +18,7 @@ interface BuildPipelineInput<
2118
initializer: PipelineInitializer<C, A>;
2219
stages: PipelineStage<A, C, R>[];
2320
resultsValidator: PipelineResultValidator<R>;
24-
middleware?: PipelineMiddleware[];
21+
middleware?: PipelineMiddleware<A, C, R>[];
2522
}
2623

2724
/**
@@ -36,7 +33,7 @@ export function buildPipeline<
3633
initializer,
3734
stages,
3835
resultsValidator,
39-
middleware = [],
36+
middleware: middlewares = [],
4037
}: BuildPipelineInput<A, C, R>): Pipeline<A, R> {
4138
return async (args) => {
4239
const results: Partial<R> = {};
@@ -55,78 +52,55 @@ export function buildPipeline<
5552
const context = await initializer(args);
5653
maybeContext = context;
5754

58-
const buildMiddlewarePayload = (
55+
const reversedMiddleware = [...middlewares].reverse();
56+
const wrapMiddleware = (
57+
middleware: PipelineMiddleware<A, C, R>,
5958
currentStage: string,
60-
): PipelineMiddlewarePayload<A, C, R> => ({
61-
context,
62-
metadata,
63-
results,
64-
stageNames,
65-
currentStage,
66-
});
59+
next: () => Promise<Partial<R>>,
60+
) => {
61+
return () => {
62+
return middleware({
63+
context,
64+
metadata,
65+
results,
66+
stageNames,
67+
currentStage,
68+
next,
69+
});
70+
};
71+
};
6772

6873
for (const stage of stages) {
69-
await executeMiddlewareForEvent(
70-
"onStageStart",
71-
middleware,
72-
buildMiddlewarePayload(stage.name),
73-
);
74+
// initialize next() with the stage itself
75+
let next = () => stage(context, metadata) as Promise<Partial<R>>;
76+
77+
// wrap stage with middleware such that the first middleware is the outermost function
78+
for (const middleware of reversedMiddleware) {
79+
next = wrapMiddleware(middleware, stage.name, next);
80+
}
7481

75-
const stageResults = await stage(context, metadata);
82+
// invoke middleware-wrapped stage
83+
const stageResults = await next();
7684

7785
// if the stage returns results, merge them onto the results object
7886
if (stageResults) {
7987
merge(results, stageResults);
8088
}
81-
82-
await executeMiddlewareForEvent(
83-
"onStageComplete",
84-
[...middleware].reverse(),
85-
buildMiddlewarePayload(stage.name),
86-
);
8789
}
8890

89-
if (!isValidResult(results, resultsValidator)) {
91+
if (!resultsValidator(results)) {
9092
throw new Error("Results from pipeline failed validation");
9193
}
9294

9395
return results;
94-
} catch (e) {
96+
} catch (cause) {
9597
throw new PipelineError(
96-
`${String(e)}`,
98+
String(cause),
9799
maybeContext,
98100
results,
99101
metadata,
100-
e,
102+
cause,
101103
);
102104
}
103105
};
104106
}
105-
106-
async function executeMiddlewareForEvent<
107-
A extends object,
108-
C extends object,
109-
R extends object,
110-
>(
111-
event: PipelineMiddlewareEventType,
112-
middleware: PipelineMiddleware[],
113-
payload: PipelineMiddlewarePayload<A, C, R>,
114-
) {
115-
const handlers = compact<PipelineMiddlewareCallable<object, object, object>>(
116-
middleware.map((m) => m[event]),
117-
);
118-
119-
for (const handler of handlers) {
120-
await handler(payload);
121-
}
122-
}
123-
124-
/**
125-
* Wraps the provided validator in a type guard
126-
*/
127-
function isValidResult<R extends object>(
128-
result: Partial<R>,
129-
validator: PipelineResultValidator<R>,
130-
): result is R {
131-
return validator(result);
132-
}
+14-9
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,20 @@
11
import { PipelineMiddleware } from "../types";
22

33
/**
4-
* A simple implementation of Pipeline middleware that logs when each stage begins and finishes
4+
* A simple implementation of Pipeline middleware that logs the duration of each stage
55
*/
66
export const logStageMiddlewareFactory = (
77
logger: (msg: string) => void = console.log,
8-
): PipelineMiddleware => ({
9-
onStageStart: ({ metadata, currentStage }) => {
10-
logger(`[${metadata.name}] starting ${currentStage}...`);
11-
},
12-
onStageComplete: ({ metadata, currentStage }) => {
13-
logger(`[${metadata.name}] ${currentStage} completed`);
14-
},
15-
});
8+
): PipelineMiddleware => {
9+
return async ({ metadata, currentStage, next }) => {
10+
const started = performance.now();
11+
12+
try {
13+
return await next();
14+
} finally {
15+
logger(
16+
`[${metadata.name}] ${currentStage} completed in ${performance.now() - started}ms`,
17+
);
18+
}
19+
};
20+
};

0 commit comments

Comments
 (0)