|
| 1 | +# This Software (Dioptra) is being made available as a public service by the |
| 2 | +# National Institute of Standards and Technology (NIST), an Agency of the United |
| 3 | +# States Department of Commerce. This software was developed in part by employees of |
| 4 | +# NIST and in part by NIST contractors. Copyright in portions of this software that |
| 5 | +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant |
| 6 | +# to Title 17 United States Code Section 105, works of NIST employees are not |
| 7 | +# subject to copyright protection in the United States. However, NIST may hold |
| 8 | +# international copyright in software created by its employees and domestic |
| 9 | +# copyright (or licensing rights) in portions of software that were assigned or |
| 10 | +# licensed to NIST. To the extent that NIST holds copyright in this software, it is |
| 11 | +# being made available under the Creative Commons Attribution 4.0 International |
| 12 | +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts |
| 13 | +# of the software developed or licensed by NIST. |
| 14 | +# |
| 15 | +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: |
| 16 | +# https://creativecommons.org/licenses/by/4.0/legalcode |
| 17 | +import hashlib |
| 18 | +import inspect |
| 19 | +from typing import Any |
| 20 | + |
| 21 | +import structlog |
| 22 | +from structlog.stdlib import BoundLogger |
| 23 | + |
| 24 | +EXPECTED_SERIALIZE_OPERATION_SHA256_HASH = "57241f0a33ed5e1771e5032d1e6f6994685185ed526b9ca2c70f4f27684d1f92" # noqa: B950; fmt: skip |
| 25 | +PATCHED_SERIALIZE_OPERATION_SHA256_HASH = "8a51bc04c8dcb81820548d9de53a9606faf0681ffc3684102744c69fbd076437" # noqa: B950; fmt: skip |
| 26 | + |
| 27 | +LOGGER: BoundLogger = structlog.stdlib.get_logger() |
| 28 | + |
| 29 | + |
| 30 | +def monkey_patch_flask_restx() -> None: |
| 31 | + """ |
| 32 | + Monkey patch flask_restx.Swagger.serialize_operation to force Swagger docs to use |
| 33 | + the multipart/form-data content type for multi-file uploads instead of the |
| 34 | + application/x-www-form-urlencoded content type. |
| 35 | +
|
| 36 | + This monkey-patch applies the proposed change in this PR |
| 37 | + https://github.com/python-restx/flask-restx/pull/542. |
| 38 | + """ |
| 39 | + import flask_restx |
| 40 | + from flask_restx.utils import not_none |
| 41 | + |
| 42 | + serialize_operation_sha256_hash = get_source_code_hash( |
| 43 | + flask_restx.Swagger.serialize_operation |
| 44 | + ) |
| 45 | + |
| 46 | + if serialize_operation_sha256_hash == PATCHED_SERIALIZE_OPERATION_SHA256_HASH: |
| 47 | + LOGGER.debug( |
| 48 | + "flask_restx.Swagger.serialize_operation already patched", |
| 49 | + sha256_hash=serialize_operation_sha256_hash, |
| 50 | + ) |
| 51 | + return None |
| 52 | + |
| 53 | + if serialize_operation_sha256_hash != EXPECTED_SERIALIZE_OPERATION_SHA256_HASH: |
| 54 | + LOGGER.error( |
| 55 | + "Source code hash changed", |
| 56 | + reason="hash of flask_restx.Swagger.serialize_operation did not match", |
| 57 | + expected_hash=EXPECTED_SERIALIZE_OPERATION_SHA256_HASH, |
| 58 | + sha256_hash=serialize_operation_sha256_hash, |
| 59 | + ) |
| 60 | + raise RuntimeError( |
| 61 | + "Source code hash changed (reason: hash of " |
| 62 | + "flask_restx.Swagger.serialize_operation did not match " |
| 63 | + f"{EXPECTED_SERIALIZE_OPERATION_SHA256_HASH}): " |
| 64 | + f"{serialize_operation_sha256_hash}" |
| 65 | + ) |
| 66 | + |
| 67 | + def serialize_operation_patched(self, doc, method): |
| 68 | + operation = { |
| 69 | + "responses": self.responses_for(doc, method) or None, |
| 70 | + "summary": doc[method]["docstring"]["summary"], |
| 71 | + "description": self.description_for(doc, method) or None, |
| 72 | + "operationId": self.operation_id_for(doc, method), |
| 73 | + "parameters": self.parameters_for(doc[method]) or None, |
| 74 | + "security": self.security_for(doc, method), |
| 75 | + } |
| 76 | + # Handle 'produces' mimetypes documentation |
| 77 | + if "produces" in doc[method]: |
| 78 | + operation["produces"] = doc[method]["produces"] |
| 79 | + # Handle deprecated annotation |
| 80 | + if doc.get("deprecated") or doc[method].get("deprecated"): |
| 81 | + operation["deprecated"] = True |
| 82 | + # Handle form exceptions: |
| 83 | + doc_params = list(doc.get("params", {}).values()) |
| 84 | + all_params = doc_params + (operation["parameters"] or []) |
| 85 | + if all_params and any(p["in"] == "formData" for p in all_params): |
| 86 | + if any(p["type"] == "file" for p in all_params): |
| 87 | + operation["consumes"] = ["multipart/form-data"] |
| 88 | + elif any( |
| 89 | + p["type"] == "array" and p["collectionFormat"] == "multi" |
| 90 | + for p in all_params |
| 91 | + if "collectionFormat" in p |
| 92 | + ): |
| 93 | + operation["consumes"] = ["multipart/form-data"] |
| 94 | + else: |
| 95 | + operation["consumes"] = [ |
| 96 | + "application/x-www-form-urlencoded", |
| 97 | + "multipart/form-data", |
| 98 | + ] |
| 99 | + operation.update(self.vendor_fields(doc, method)) |
| 100 | + return not_none(operation) |
| 101 | + |
| 102 | + flask_restx.Swagger.serialize_operation = serialize_operation_patched |
| 103 | + LOGGER.info("flask_restx.Swagger.serialize_operation patched successfully") |
| 104 | + |
| 105 | + |
| 106 | +def get_source_code_hash(obj: Any) -> str: |
| 107 | + """Generate a hash of the underlying source code of a Python object. |
| 108 | +
|
| 109 | + Args: |
| 110 | + obj: The Python object for which to generate a source code hash. |
| 111 | +
|
| 112 | + Returns: |
| 113 | + The hash of the source code of the Python object. |
| 114 | + """ |
| 115 | + |
| 116 | + hash_sha256 = hashlib.sha256() |
| 117 | + source_lines, _ = inspect.getsourcelines(obj) |
| 118 | + source_lines = [line.rstrip() for line in source_lines] |
| 119 | + |
| 120 | + for line in source_lines: |
| 121 | + hash_sha256.update(line.encode("utf-8")) |
| 122 | + |
| 123 | + return hash_sha256.hexdigest() |
0 commit comments