diff --git a/.gitignore b/.gitignore index bd6ad262..cb854d24 100644 --- a/.gitignore +++ b/.gitignore @@ -130,3 +130,6 @@ dmypy.json # PyCharm .idea/ + +# Visual Studio Code +.vscode/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 718a5aa9..d6f1c122 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,37 +1,52 @@ repos: - - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.0.1 - hooks: - - id: trailing-whitespace - - id: check-added-large-files - - id: end-of-file-fixer - - id: mixed-line-ending - args: ["--fix=lf"] + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.0.1 + hooks: + - id: trailing-whitespace + - id: check-added-large-files + - id: end-of-file-fixer + - id: mixed-line-ending + args: ["--fix=lf"] - - repo: https://github.com/pre-commit/mirrors-isort - rev: v5.8.0 - hooks: - - id: isort - args: - [ - "--multi-line=3", - "--trailing-comma", - "--force-grid-wrap=0", - "--use-parentheses", - "--line-width=88", - ] + - repo: https://github.com/pycqa/isort + rev: 5.10.1 + hooks: + - id: isort + args: + [ + "--profile", + "black", + "--multi-line=3", + "--trailing-comma", + "--force-grid-wrap=0", + "--use-parentheses", + "--line-width=88", + ] - - repo: https://github.com/humitos/mirrors-autoflake.git - rev: v1.1 - hooks: - - id: autoflake - args: ["--in-place", "--remove-all-unused-imports"] + - repo: https://github.com/myint/autoflake.git + rev: v1.4 + hooks: + - id: autoflake + args: + [ + "--in-place", + "--remove-all-unused-imports", + "--ignore-init-module-imports", + ] - - repo: https://github.com/psf/black - rev: 21.9b0 - hooks: - - id: black - - repo: https://github.com/pre-commit/mirrors-prettier - rev: v2.4.1 - hooks: - - id: prettier + - repo: https://github.com/ambv/black + rev: 21.11b1 + hooks: + - id: black + + - repo: https://github.com/asottile/pyupgrade + rev: v2.29.1 + hooks: + - id: pyupgrade + args: ["--py37-plus", "--keep-runtime-typing"] + + - repo: https://github.com/commitizen-tools/commitizen + rev: v2.20.0 + hooks: + - id: commitizen + stages: [commit-msg] diff --git a/postgrest_py/__init__.py b/postgrest_py/__init__.py index 8505e96f..eadf2ebd 100644 --- a/postgrest_py/__init__.py +++ b/postgrest_py/__init__.py @@ -1,13 +1,17 @@ -from postgrest_py._async.client import AsyncPostgrestClient # noqa: F401 -from postgrest_py._async.request_builder import AsyncFilterRequestBuilder # noqa: F401 -from postgrest_py._async.request_builder import AsyncQueryRequestBuilder # noqa: F401 -from postgrest_py._async.request_builder import AsyncRequestBuilder # noqa: F401 -from postgrest_py._async.request_builder import AsyncSelectRequestBuilder # noqa: F401 -from postgrest_py._sync.client import SyncPostgrestClient # noqa: F401 -from postgrest_py._sync.request_builder import SyncFilterRequestBuilder # noqa: F401 -from postgrest_py._sync.request_builder import SyncQueryRequestBuilder # noqa: F401 -from postgrest_py._sync.request_builder import SyncRequestBuilder # noqa: F401 -from postgrest_py._sync.request_builder import SyncSelectRequestBuilder # noqa: F401 -from postgrest_py.config import DEFAULT_POSTGREST_CLIENT_HEADERS # noqa: F401 -from postgrest_py.deprecated_client import Client, PostgrestClient # noqa: F401 -from postgrest_py.deprecated_get_request_builder import GetRequestBuilder # noqa: F401 +from postgrest_py._async.client import AsyncPostgrestClient +from postgrest_py._async.request_builder import ( + AsyncFilterRequestBuilder, + AsyncQueryRequestBuilder, + AsyncRequestBuilder, + AsyncSelectRequestBuilder, +) +from postgrest_py._sync.client import SyncPostgrestClient +from postgrest_py._sync.request_builder import ( + SyncFilterRequestBuilder, + SyncQueryRequestBuilder, + SyncRequestBuilder, + SyncSelectRequestBuilder, +) +from postgrest_py.base_client import DEFAULT_POSTGREST_CLIENT_HEADERS +from postgrest_py.deprecated_client import Client, PostgrestClient +from postgrest_py.deprecated_get_request_builder import GetRequestBuilder diff --git a/postgrest_py/_async/client.py b/postgrest_py/_async/client.py index e1c11dc8..aa0a90ce 100644 --- a/postgrest_py/_async/client.py +++ b/postgrest_py/_async/client.py @@ -1,16 +1,19 @@ -from typing import Dict, Optional, Union +from typing import Dict, cast from deprecation import deprecated -from httpx import BasicAuth, Response +from httpx import Response from postgrest_py.__version__ import __version__ -from postgrest_py.config import DEFAULT_POSTGREST_CLIENT_HEADERS +from postgrest_py.base_client import ( + DEFAULT_POSTGREST_CLIENT_HEADERS, + BasePostgrestClient, +) from postgrest_py.utils import AsyncClient from .request_builder import AsyncRequestBuilder -class AsyncPostgrestClient: +class AsyncPostgrestClient(BasePostgrestClient): """PostgREST client.""" def __init__( @@ -20,14 +23,17 @@ def __init__( schema: str = "public", headers: Dict[str, str] = DEFAULT_POSTGREST_CLIENT_HEADERS, ) -> None: - headers = { - **headers, - "Accept-Profile": schema, - "Content-Profile": schema, - } - self.session = AsyncClient(base_url=base_url, headers=headers) + BasePostgrestClient.__init__(self, base_url, schema=schema, headers=headers) + self.session = cast(AsyncClient, self.session) - async def __aenter__(self): + def create_session( + self, + base_url: str, + headers: Dict[str, str], + ) -> AsyncClient: + return AsyncClient(base_url=base_url, headers=headers) + + async def __aenter__(self) -> "AsyncPostgrestClient": return self async def __aexit__(self, exc_type, exc, tb) -> None: @@ -36,41 +42,17 @@ async def __aexit__(self, exc_type, exc, tb) -> None: async def aclose(self) -> None: await self.session.aclose() - def auth( - self, - token: Optional[str], - *, - username: Union[str, bytes, None] = None, - password: Union[str, bytes] = "", - ): - """ - Authenticate the client with either bearer token or basic authentication. - - Raise `ValueError` if neither authentication scheme is provided. - Bearer token is preferred if both ones are provided. - """ - if token: - self.session.headers["Authorization"] = f"Bearer {token}" - elif username: - self.session.auth = BasicAuth(username, password) - else: - raise ValueError( - "Neither bearer token or basic authentication scheme is provided" - ) - return self - - def schema(self, schema: str): - """Switch to another schema.""" - self.session.headers.update({"Accept-Profile": schema, "Content-Profile": schema}) - return self - def from_(self, table: str) -> AsyncRequestBuilder: """Perform a table operation.""" return AsyncRequestBuilder(self.session, f"/{table}") + + def table(self, table: str) -> AsyncRequestBuilder: + """Alias to self.from_().""" + return self.from_(table) - @deprecated("0.2.0", "1.0.0", __version__, "Use PostgrestClient.from_() instead") + @deprecated("0.2.0", "1.0.0", __version__, "Use self.from_() instead") def from_table(self, table: str) -> AsyncRequestBuilder: - """Alias to Self.from_().""" + """Alias to self.from_().""" return self.from_(table) async def rpc(self, func: str, params: dict) -> Response: diff --git a/postgrest_py/_async/request_builder.py b/postgrest_py/_async/request_builder.py index 9c6ce9c8..c058940f 100644 --- a/postgrest_py/_async/request_builder.py +++ b/postgrest_py/_async/request_builder.py @@ -1,205 +1,119 @@ -import re -import sys -from typing import Any, Dict, Iterable, Optional, Tuple, Union - -if sys.version_info < (3, 8): - from typing_extensions import Literal -else: - from typing import Literal - -from postgrest_py.utils import AsyncClient, sanitize_param, sanitize_pattern_param - -CountMethod = Union[Literal["exact"], Literal["planned"], Literal["estimated"]] - - -class AsyncRequestBuilder: - def __init__(self, session: AsyncClient, path: str): - self.session = session - self.path = path - - def select(self, *columns: str, count: Optional[CountMethod] = None): - if columns: - method = "GET" - self.session.params = self.session.params.set("select", ",".join(columns)) - else: - method = "HEAD" - - if count: - self.session.headers["Prefer"] = f"count={count}" - - return AsyncSelectRequestBuilder(self.session, self.path, method, {}) - - def insert(self, json: dict, *, count: Optional[CountMethod] = None, upsert=False): - prefer_headers = ["return=representation"] - if count: - prefer_headers.append(f"count={count}") - if upsert: - prefer_headers.append("resolution=merge-duplicates") - self.session.headers["prefer"] = ",".join(prefer_headers) - return AsyncQueryRequestBuilder(self.session, self.path, "POST", json) - - def update(self, json: dict, *, count: Optional[CountMethod] = None): - prefer_headers = ["return=representation"] - if count: - prefer_headers.append(f"count={count}") - self.session.headers["prefer"] = ",".join(prefer_headers) - return AsyncFilterRequestBuilder(self.session, self.path, "PATCH", json) - - def delete(self, *, count: Optional[CountMethod] = None): - prefer_headers = ["return=representation"] - if count: - prefer_headers.append(f"count={count}") - self.session.headers["prefer"] = ",".join(prefer_headers) - return AsyncFilterRequestBuilder(self.session, self.path, "DELETE", {}) +from typing import Any, Optional, Tuple + +from postgrest_py.base_request_builder import ( + BaseFilterRequestBuilder, + BaseSelectRequestBuilder, + CountMethod, + pre_delete, + pre_insert, + pre_select, + pre_update, + pre_upsert, + process_response, +) +from postgrest_py.utils import AsyncClient class AsyncQueryRequestBuilder: - def __init__(self, session: AsyncClient, path: str, http_method: str, json: dict): + def __init__( + self, + session: AsyncClient, + path: str, + http_method: str, + json: dict, + ) -> None: self.session = session self.path = path self.http_method = http_method self.json = json async def execute(self) -> Tuple[Any, Optional[int]]: - r = await self.session.request(self.http_method, self.path, json=self.json) - - count = None - try: - count_header_match = re.search( - "count=(exact|planned|estimated)", self.session.headers["prefer"] - ) - content_range = r.headers["content-range"].split("/") - if count_header_match and len(content_range) >= 2: - count = int(content_range[1]) - except KeyError: - ... - - return r.json(), count - - -class AsyncFilterRequestBuilder(AsyncQueryRequestBuilder): - def __init__(self, session: AsyncClient, path: str, http_method: str, json: dict): - super().__init__(session, path, http_method, json) - - self.negate_next = False - - @property - def not_(self): - self.negate_next = True - return self - - def filter(self, column: str, operator: str, criteria: str): - """Either filter in or filter out based on Self.negate_next.""" - if self.negate_next is True: - self.negate_next = False - operator = f"not.{operator}" - key, val = sanitize_param(column), f"{operator}.{criteria}" - self.session.params = self.session.params.add(key, val) - return self - - def eq(self, column: str, value: str): - return self.filter(column, "eq", sanitize_param(value)) - - def neq(self, column: str, value: str): - return self.filter(column, "neq", sanitize_param(value)) - - def gt(self, column: str, value: str): - return self.filter(column, "gt", sanitize_param(value)) - - def gte(self, column: str, value: str): - return self.filter(column, "gte", sanitize_param(value)) - - def lt(self, column: str, value: str): - return self.filter(column, "lt", sanitize_param(value)) + r = await self.session.request( + self.http_method, + self.path, + json=self.json, + ) + return process_response(self.session, r) + + +class AsyncFilterRequestBuilder(BaseFilterRequestBuilder, AsyncQueryRequestBuilder): + def __init__( + self, + session: AsyncClient, + path: str, + http_method: str, + json: dict, + ) -> None: + BaseFilterRequestBuilder.__init__(self, session) + AsyncQueryRequestBuilder.__init__(self, session, path, http_method, json) + + +class AsyncSelectRequestBuilder(BaseSelectRequestBuilder, AsyncQueryRequestBuilder): + def __init__( + self, + session: AsyncClient, + path: str, + http_method: str, + json: dict, + ) -> None: + BaseSelectRequestBuilder.__init__(self, session) + AsyncQueryRequestBuilder.__init__(self, session, path, http_method, json) - def lte(self, column: str, value: str): - return self.filter(column, "lte", sanitize_param(value)) - def is_(self, column: str, value: str): - return self.filter(column, "is", sanitize_param(value)) - - def like(self, column: str, pattern: str): - return self.filter(column, "like", sanitize_pattern_param(pattern)) - - def ilike(self, column: str, pattern: str): - return self.filter(column, "ilike", sanitize_pattern_param(pattern)) - - def fts(self, column: str, query: str): - return self.filter(column, "fts", sanitize_param(query)) - - def plfts(self, column: str, query: str): - return self.filter(column, "plfts", sanitize_param(query)) - - def phfts(self, column: str, query: str): - return self.filter(column, "phfts", sanitize_param(query)) - - def wfts(self, column: str, query: str): - return self.filter(column, "wfts", sanitize_param(query)) - - def in_(self, column: str, values: Iterable[str]): - values = map(sanitize_param, values) - values = ",".join(values) - return self.filter(column, "in", f"({values})") - - def cs(self, column: str, values: Iterable[str]): - values = map(sanitize_param, values) - values = ",".join(values) - return self.filter(column, "cs", f"{{{values}}}") - - def cd(self, column: str, values: Iterable[str]): - values = map(sanitize_param, values) - values = ",".join(values) - return self.filter(column, "cd", f"{{{values}}}") - - def ov(self, column: str, values: Iterable[str]): - values = map(sanitize_param, values) - values = ",".join(values) - return self.filter(column, "ov", f"{{{values}}}") - - def sl(self, column: str, range: Tuple[int, int]): - return self.filter(column, "sl", f"({range[0]},{range[1]})") - - def sr(self, column: str, range: Tuple[int, int]): - return self.filter(column, "sr", f"({range[0]},{range[1]})") - - def nxl(self, column: str, range: Tuple[int, int]): - return self.filter(column, "nxl", f"({range[0]},{range[1]})") - - def nxr(self, column: str, range: Tuple[int, int]): - return self.filter(column, "nxr", f"({range[0]},{range[1]})") - - def adj(self, column: str, range: Tuple[int, int]): - return self.filter(column, "adj", f"({range[0]},{range[1]})") - - def match(self, query: Dict[str, Any]): - updated_query = None - for key in query: - value = query.get(key, "") - updated_query = self.eq(key, value) - return updated_query - - -class AsyncSelectRequestBuilder(AsyncFilterRequestBuilder): - def __init__(self, session: AsyncClient, path: str, http_method: str, json: dict): - super().__init__(session, path, http_method, json) - - def order(self, column: str, *, desc=False, nullsfirst=False): - self.session.params[ - "order" - ] = f"{column}{'.desc' if desc else ''}{'.nullsfirst' if nullsfirst else ''}" - - return self - - def limit(self, size: int, *, start=0): - self.session.headers["Range-Unit"] = "items" - self.session.headers["Range"] = f"{start}-{start + size - 1}" - return self - - def range(self, start: int, end: int): - self.session.headers["Range-Unit"] = "items" - self.session.headers["Range"] = f"{start}-{end - 1}" - return self +class AsyncRequestBuilder: + def __init__(self, session: AsyncClient, path: str) -> None: + self.session = session + self.path = path - def single(self): - self.session.headers["Accept"] = "application/vnd.pgrst.object+json" - return self + def select( + self, + *columns: str, + count: Optional[CountMethod] = None, + ) -> AsyncSelectRequestBuilder: + method, json = pre_select(self.session, self.path, *columns, count=count) + return AsyncSelectRequestBuilder(self.session, self.path, method, json) + + def insert( + self, + json: dict, + *, + count: Optional[CountMethod] = None, + upsert=False, + ) -> AsyncQueryRequestBuilder: + method, json = pre_insert( + self.session, + self.path, + json, + count=count, + upsert=upsert, + ) + return AsyncQueryRequestBuilder(self.session, self.path, method, json) + + def upsert( + self, + json: dict, + *, + count: Optional[CountMethod] = None, + ignore_duplicates=False, + ) -> AsyncQueryRequestBuilder: + method, json = pre_upsert( + self.session, + self.path, + json, + count=count, + ignore_duplicates=ignore_duplicates, + ) + return AsyncQueryRequestBuilder(self.session, self.path, method, json) + + def update( + self, + json: dict, + *, + count: Optional[CountMethod] = None, + ) -> AsyncFilterRequestBuilder: + method, json = pre_update(self.session, self.path, json, count=count) + return AsyncFilterRequestBuilder(self.session, self.path, method, json) + + def delete(self, *, count: Optional[CountMethod] = None) -> AsyncFilterRequestBuilder: + method, json = pre_delete(self.session, self.path, count=count) + return AsyncFilterRequestBuilder(self.session, self.path, method, json) diff --git a/postgrest_py/_sync/client.py b/postgrest_py/_sync/client.py index 898d4c6e..192271cd 100644 --- a/postgrest_py/_sync/client.py +++ b/postgrest_py/_sync/client.py @@ -1,16 +1,19 @@ -from typing import Dict, Optional, Union +from typing import Dict, cast from deprecation import deprecated -from httpx import BasicAuth, Response +from httpx import Response from postgrest_py.__version__ import __version__ -from postgrest_py.config import DEFAULT_POSTGREST_CLIENT_HEADERS +from postgrest_py.base_client import ( + DEFAULT_POSTGREST_CLIENT_HEADERS, + BasePostgrestClient, +) from postgrest_py.utils import SyncClient from .request_builder import SyncRequestBuilder -class SyncPostgrestClient: +class SyncPostgrestClient(BasePostgrestClient): """PostgREST client.""" def __init__( @@ -20,14 +23,17 @@ def __init__( schema: str = "public", headers: Dict[str, str] = DEFAULT_POSTGREST_CLIENT_HEADERS, ) -> None: - headers = { - **headers, - "Accept-Profile": schema, - "Content-Profile": schema, - } - self.session = SyncClient(base_url=base_url, headers=headers) + BasePostgrestClient.__init__(self, base_url, schema=schema, headers=headers) + self.session = cast(SyncClient, self.session) - def __enter__(self): + def create_session( + self, + base_url: str, + headers: Dict[str, str], + ) -> SyncClient: + return SyncClient(base_url=base_url, headers=headers) + + def __enter__(self) -> "SyncPostgrestClient": return self def __exit__(self, exc_type, exc, tb) -> None: @@ -36,41 +42,17 @@ def __exit__(self, exc_type, exc, tb) -> None: def aclose(self) -> None: self.session.aclose() - def auth( - self, - token: Optional[str], - *, - username: Union[str, bytes, None] = None, - password: Union[str, bytes] = "", - ): - """ - Authenticate the client with either bearer token or basic authentication. - - Raise `ValueError` if neither authentication scheme is provided. - Bearer token is preferred if both ones are provided. - """ - if token: - self.session.headers["Authorization"] = f"Bearer {token}" - elif username: - self.session.auth = BasicAuth(username, password) - else: - raise ValueError( - "Neither bearer token or basic authentication scheme is provided" - ) - return self - - def schema(self, schema: str): - """Switch to another schema.""" - self.session.headers.update({"Accept-Profile": schema, "Content-Profile": schema}) - return self - def from_(self, table: str) -> SyncRequestBuilder: """Perform a table operation.""" return SyncRequestBuilder(self.session, f"/{table}") + + def table(self, table: str) -> SyncRequestBuilder: + """Alias to self.from_().""" + return self.from_(table) - @deprecated("0.2.0", "1.0.0", __version__, "Use PostgrestClient.from_() instead") + @deprecated("0.2.0", "1.0.0", __version__, "Use self.from_() instead") def from_table(self, table: str) -> SyncRequestBuilder: - """Alias to Self.from_().""" + """Alias to self.from_().""" return self.from_(table) def rpc(self, func: str, params: dict) -> Response: diff --git a/postgrest_py/_sync/request_builder.py b/postgrest_py/_sync/request_builder.py index 08c784b1..01cd3a97 100644 --- a/postgrest_py/_sync/request_builder.py +++ b/postgrest_py/_sync/request_builder.py @@ -1,205 +1,119 @@ -import re -import sys -from typing import Any, Dict, Iterable, Optional, Tuple, Union - -if sys.version_info < (3, 8): - from typing_extensions import Literal -else: - from typing import Literal - -from postgrest_py.utils import SyncClient, sanitize_param, sanitize_pattern_param - -CountMethod = Union[Literal["exact"], Literal["planned"], Literal["estimated"]] - - -class SyncRequestBuilder: - def __init__(self, session: SyncClient, path: str): - self.session = session - self.path = path - - def select(self, *columns: str, count: Optional[CountMethod] = None): - if columns: - method = "GET" - self.session.params = self.session.params.set("select", ",".join(columns)) - else: - method = "HEAD" - - if count: - self.session.headers["Prefer"] = f"count={count}" - - return SyncSelectRequestBuilder(self.session, self.path, method, {}) - - def insert(self, json: dict, *, count: Optional[CountMethod] = None, upsert=False): - prefer_headers = ["return=representation"] - if count: - prefer_headers.append(f"count={count}") - if upsert: - prefer_headers.append("resolution=merge-duplicates") - self.session.headers["prefer"] = ",".join(prefer_headers) - return SyncQueryRequestBuilder(self.session, self.path, "POST", json) - - def update(self, json: dict, *, count: Optional[CountMethod] = None): - prefer_headers = ["return=representation"] - if count: - prefer_headers.append(f"count={count}") - self.session.headers["prefer"] = ",".join(prefer_headers) - return SyncFilterRequestBuilder(self.session, self.path, "PATCH", json) - - def delete(self, *, count: Optional[CountMethod] = None): - prefer_headers = ["return=representation"] - if count: - prefer_headers.append(f"count={count}") - self.session.headers["prefer"] = ",".join(prefer_headers) - return SyncFilterRequestBuilder(self.session, self.path, "DELETE", {}) +from typing import Any, Optional, Tuple + +from postgrest_py.base_request_builder import ( + BaseFilterRequestBuilder, + BaseSelectRequestBuilder, + CountMethod, + pre_delete, + pre_insert, + pre_select, + pre_update, + pre_upsert, + process_response, +) +from postgrest_py.utils import SyncClient class SyncQueryRequestBuilder: - def __init__(self, session: SyncClient, path: str, http_method: str, json: dict): + def __init__( + self, + session: SyncClient, + path: str, + http_method: str, + json: dict, + ) -> None: self.session = session self.path = path self.http_method = http_method self.json = json def execute(self) -> Tuple[Any, Optional[int]]: - r = self.session.request(self.http_method, self.path, json=self.json) - - count = None - try: - count_header_match = re.search( - "count=(exact|planned|estimated)", self.session.headers["prefer"] - ) - content_range = r.headers["content-range"].split("/") - if count_header_match and len(content_range) >= 2: - count = int(content_range[1]) - except KeyError: - ... - - return r.json(), count - - -class SyncFilterRequestBuilder(SyncQueryRequestBuilder): - def __init__(self, session: SyncClient, path: str, http_method: str, json: dict): - super().__init__(session, path, http_method, json) - - self.negate_next = False - - @property - def not_(self): - self.negate_next = True - return self - - def filter(self, column: str, operator: str, criteria: str): - """Either filter in or filter out based on Self.negate_next.""" - if self.negate_next is True: - self.negate_next = False - operator = f"not.{operator}" - key, val = sanitize_param(column), f"{operator}.{criteria}" - self.session.params = self.session.params.add(key, val) - return self - - def eq(self, column: str, value: str): - return self.filter(column, "eq", sanitize_param(value)) - - def neq(self, column: str, value: str): - return self.filter(column, "neq", sanitize_param(value)) - - def gt(self, column: str, value: str): - return self.filter(column, "gt", sanitize_param(value)) - - def gte(self, column: str, value: str): - return self.filter(column, "gte", sanitize_param(value)) - - def lt(self, column: str, value: str): - return self.filter(column, "lt", sanitize_param(value)) + r = self.session.request( + self.http_method, + self.path, + json=self.json, + ) + return process_response(self.session, r) + + +class SyncFilterRequestBuilder(BaseFilterRequestBuilder, SyncQueryRequestBuilder): + def __init__( + self, + session: SyncClient, + path: str, + http_method: str, + json: dict, + ) -> None: + BaseFilterRequestBuilder.__init__(self, session) + SyncQueryRequestBuilder.__init__(self, session, path, http_method, json) + + +class SyncSelectRequestBuilder(BaseSelectRequestBuilder, SyncQueryRequestBuilder): + def __init__( + self, + session: SyncClient, + path: str, + http_method: str, + json: dict, + ) -> None: + BaseSelectRequestBuilder.__init__(self, session) + SyncQueryRequestBuilder.__init__(self, session, path, http_method, json) - def lte(self, column: str, value: str): - return self.filter(column, "lte", sanitize_param(value)) - def is_(self, column: str, value: str): - return self.filter(column, "is", sanitize_param(value)) - - def like(self, column: str, pattern: str): - return self.filter(column, "like", sanitize_pattern_param(pattern)) - - def ilike(self, column: str, pattern: str): - return self.filter(column, "ilike", sanitize_pattern_param(pattern)) - - def fts(self, column: str, query: str): - return self.filter(column, "fts", sanitize_param(query)) - - def plfts(self, column: str, query: str): - return self.filter(column, "plfts", sanitize_param(query)) - - def phfts(self, column: str, query: str): - return self.filter(column, "phfts", sanitize_param(query)) - - def wfts(self, column: str, query: str): - return self.filter(column, "wfts", sanitize_param(query)) - - def in_(self, column: str, values: Iterable[str]): - values = map(sanitize_param, values) - values = ",".join(values) - return self.filter(column, "in", f"({values})") - - def cs(self, column: str, values: Iterable[str]): - values = map(sanitize_param, values) - values = ",".join(values) - return self.filter(column, "cs", f"{{{values}}}") - - def cd(self, column: str, values: Iterable[str]): - values = map(sanitize_param, values) - values = ",".join(values) - return self.filter(column, "cd", f"{{{values}}}") - - def ov(self, column: str, values: Iterable[str]): - values = map(sanitize_param, values) - values = ",".join(values) - return self.filter(column, "ov", f"{{{values}}}") - - def sl(self, column: str, range: Tuple[int, int]): - return self.filter(column, "sl", f"({range[0]},{range[1]})") - - def sr(self, column: str, range: Tuple[int, int]): - return self.filter(column, "sr", f"({range[0]},{range[1]})") - - def nxl(self, column: str, range: Tuple[int, int]): - return self.filter(column, "nxl", f"({range[0]},{range[1]})") - - def nxr(self, column: str, range: Tuple[int, int]): - return self.filter(column, "nxr", f"({range[0]},{range[1]})") - - def adj(self, column: str, range: Tuple[int, int]): - return self.filter(column, "adj", f"({range[0]},{range[1]})") - - def match(self, query: Dict[str, Any]): - updated_query = None - for key in query: - value = query.get(key, "") - updated_query = self.eq(key, value) - return updated_query - - -class SyncSelectRequestBuilder(SyncFilterRequestBuilder): - def __init__(self, session: SyncClient, path: str, http_method: str, json: dict): - super().__init__(session, path, http_method, json) - - def order(self, column: str, *, desc=False, nullsfirst=False): - self.session.params[ - "order" - ] = f"{column}{'.desc' if desc else ''}{'.nullsfirst' if nullsfirst else ''}" - - return self - - def limit(self, size: int, *, start=0): - self.session.headers["Range-Unit"] = "items" - self.session.headers["Range"] = f"{start}-{start + size - 1}" - return self - - def range(self, start: int, end: int): - self.session.headers["Range-Unit"] = "items" - self.session.headers["Range"] = f"{start}-{end - 1}" - return self +class SyncRequestBuilder: + def __init__(self, session: SyncClient, path: str) -> None: + self.session = session + self.path = path - def single(self): - self.session.headers["Accept"] = "application/vnd.pgrst.object+json" - return self + def select( + self, + *columns: str, + count: Optional[CountMethod] = None, + ) -> SyncSelectRequestBuilder: + method, json = pre_select(self.session, self.path, *columns, count=count) + return SyncSelectRequestBuilder(self.session, self.path, method, json) + + def insert( + self, + json: dict, + *, + count: Optional[CountMethod] = None, + upsert=False, + ) -> SyncQueryRequestBuilder: + method, json = pre_insert( + self.session, + self.path, + json, + count=count, + upsert=upsert, + ) + return SyncQueryRequestBuilder(self.session, self.path, method, json) + + def upsert( + self, + json: dict, + *, + count: Optional[CountMethod] = None, + ignore_duplicates=False, + ) -> SyncQueryRequestBuilder: + method, json = pre_upsert( + self.session, + self.path, + json, + count=count, + ignore_duplicates=ignore_duplicates, + ) + return SyncQueryRequestBuilder(self.session, self.path, method, json) + + def update( + self, + json: dict, + *, + count: Optional[CountMethod] = None, + ) -> SyncFilterRequestBuilder: + method, json = pre_update(self.session, self.path, json, count=count) + return SyncFilterRequestBuilder(self.session, self.path, method, json) + + def delete(self, *, count: Optional[CountMethod] = None) -> SyncFilterRequestBuilder: + method, json = pre_delete(self.session, self.path, count=count) + return SyncFilterRequestBuilder(self.session, self.path, method, json) diff --git a/postgrest_py/base_client.py b/postgrest_py/base_client.py new file mode 100644 index 00000000..dfe2078f --- /dev/null +++ b/postgrest_py/base_client.py @@ -0,0 +1,58 @@ +from abc import ABC, abstractmethod +from typing import Dict, Optional, Union + +from httpx import BasicAuth + +from postgrest_py.utils import AsyncClient, SyncClient + +DEFAULT_POSTGREST_CLIENT_HEADERS: Dict[str, str] = { + "Accept": "application/json", + "Content-Type": "application/json", +} + + +class BasePostgrestClient(ABC): + """Base PostgREST client.""" + + def __init__(self, base_url: str, *, schema: str, headers: Dict[str, str]) -> None: + headers = { + **headers, + "Accept-Profile": schema, + "Content-Profile": schema, + } + self.session = self.create_session(base_url, headers) + + @abstractmethod + def create_session( + self, + base_url: str, + headers: Dict[str, str], + ) -> Union[SyncClient, AsyncClient]: + raise NotImplementedError() + + def auth( + self, + token: Optional[str], + *, + username: Union[str, bytes, None] = None, + password: Union[str, bytes] = "", + ): + """ + Authenticate the client with either bearer token or basic authentication. + Raise `ValueError` if neither authentication scheme is provided. + Bearer token is preferred if both ones are provided. + """ + if token: + self.session.headers["Authorization"] = f"Bearer {token}" + elif username: + self.session.auth = BasicAuth(username, password) + else: + raise ValueError( + "Neither bearer token or basic authentication scheme is provided" + ) + return self + + def schema(self, schema: str): + """Switch to another schema.""" + self.session.headers.update({"Accept-Profile": schema, "Content-Profile": schema}) + return self diff --git a/postgrest_py/base_request_builder.py b/postgrest_py/base_request_builder.py new file mode 100644 index 00000000..0ebbb893 --- /dev/null +++ b/postgrest_py/base_request_builder.py @@ -0,0 +1,232 @@ +from re import search +from typing import Any, Dict, Iterable, Optional, Tuple, Union + +from httpx import Response + +from postgrest_py.constants import CountMethod, Filters, RequestMethod +from postgrest_py.utils import ( + AsyncClient, + SyncClient, + sanitize_param, + sanitize_pattern_param, +) + + +def pre_select( + session: Union[AsyncClient, SyncClient], + path: str, + *columns: str, + count: Optional[CountMethod] = None, +) -> Tuple[RequestMethod, dict]: + if columns: + method = RequestMethod.GET + session.params = session.params.set("select", ",".join(columns)) + else: + method = RequestMethod.HEAD + if count: + session.headers["Prefer"] = f"count={count}" + return method, {} + + +def pre_insert( + session: Union[AsyncClient, SyncClient], + path: str, + json: dict, + *, + count: Optional[CountMethod] = None, + upsert=False, +) -> Tuple[RequestMethod, dict]: + prefer_headers = ["return=representation"] + if count: + prefer_headers.append(f"count={count}") + if upsert: + prefer_headers.append("resolution=merge-duplicates") + session.headers["prefer"] = ",".join(prefer_headers) + return RequestMethod.POST, json + + +def pre_upsert( + session: Union[AsyncClient, SyncClient], + path: str, + json: dict, + *, + count: Optional[CountMethod] = None, + ignore_duplicates=False, +) -> Tuple[RequestMethod, dict]: + prefer_headers = ["return=representation"] + if count: + prefer_headers.append(f"count={count}") + resolution = "ignore" if ignore_duplicates else "merge" + prefer_headers.append(f"resolution={resolution}-duplicates") + session.headers["prefer"] = ",".join(prefer_headers) + return RequestMethod.POST, json + + +def pre_update( + session: Union[AsyncClient, SyncClient], + path: str, + json: dict, + *, + count: Optional[CountMethod] = None, +) -> Tuple[RequestMethod, dict]: + prefer_headers = ["return=representation"] + if count: + prefer_headers.append(f"count={count}") + session.headers["prefer"] = ",".join(prefer_headers) + return RequestMethod.PATCH, json + + +def pre_delete( + session: Union[AsyncClient, SyncClient], + path: str, + *, + count: Optional[CountMethod] = None, +) -> Tuple[RequestMethod, dict]: + prefer_headers = ["return=representation"] + if count: + prefer_headers.append(f"count={count}") + session.headers["prefer"] = ",".join(prefer_headers) + return RequestMethod.DELETE, {} + + +def process_response( + session: Union[AsyncClient, SyncClient], + r: Response, +) -> Tuple[Any, Optional[int]]: + count = None + prefer_header = session.headers.get("prefer") + if prefer_header: + pattern = f"count=({'|'.join([cm.value for cm in CountMethod])})" + count_header_match = search(pattern, prefer_header) + content_range_header = r.headers.get("content-range") + if count_header_match and content_range_header: + content_range = content_range_header.split("/") + if len(content_range) >= 2: + count = int(content_range[1]) + return r.json(), count + + +class BaseFilterRequestBuilder: + def __init__(self, session: Union[AsyncClient, SyncClient]): + self.session = session + self.negate_next = False + + @property + def not_(self): + self.negate_next = True + return self + + def filter(self, column: str, operator: str, criteria: str): + """Either filter in or filter out based on `self.negate_next.`""" + if self.negate_next is True: + self.negate_next = False + operator = f"{Filters.NOT}.{operator}" + key, val = sanitize_param(column), f"{operator}.{criteria}" + self.session.params = self.session.params.add(key, val) + return self + + def eq(self, column: str, value: str): + return self.filter(column, Filters.EQ, sanitize_param(value)) + + def neq(self, column: str, value: str): + return self.filter(column, Filters.NEQ, sanitize_param(value)) + + def gt(self, column: str, value: str): + return self.filter(column, Filters.GT, sanitize_param(value)) + + def gte(self, column: str, value: str): + return self.filter(column, Filters.GTE, sanitize_param(value)) + + def lt(self, column: str, value: str): + return self.filter(column, Filters.LT, sanitize_param(value)) + + def lte(self, column: str, value: str): + return self.filter(column, Filters.LTE, sanitize_param(value)) + + def is_(self, column: str, value: str): + return self.filter(column, Filters.IS, sanitize_param(value)) + + def like(self, column: str, pattern: str): + return self.filter(column, Filters.LIKE, sanitize_pattern_param(pattern)) + + def ilike(self, column: str, pattern: str): + return self.filter(column, Filters.ILIKE, sanitize_pattern_param(pattern)) + + def fts(self, column: str, query: str): + return self.filter(column, Filters.FTS, sanitize_param(query)) + + def plfts(self, column: str, query: str): + return self.filter(column, Filters.PLFTS, sanitize_param(query)) + + def phfts(self, column: str, query: str): + return self.filter(column, Filters.PHFTS, sanitize_param(query)) + + def wfts(self, column: str, query: str): + return self.filter(column, Filters.WFTS, sanitize_param(query)) + + def in_(self, column: str, values: Iterable[str]): + values = map(sanitize_param, values) + values = ",".join(values) + return self.filter(column, Filters.IN, f"({values})") + + def cs(self, column: str, values: Iterable[str]): + values = map(sanitize_param, values) + values = ",".join(values) + return self.filter(column, Filters.CS, f"{{{values}}}") + + def cd(self, column: str, values: Iterable[str]): + values = map(sanitize_param, values) + values = ",".join(values) + return self.filter(column, Filters.CD, f"{{{values}}}") + + def ov(self, column: str, values: Iterable[str]): + values = map(sanitize_param, values) + values = ",".join(values) + return self.filter(column, Filters.OV, f"{{{values}}}") + + def sl(self, column: str, range: Tuple[int, int]): + return self.filter(column, Filters.SL, f"({range[0]},{range[1]})") + + def sr(self, column: str, range: Tuple[int, int]): + return self.filter(column, Filters.SR, f"({range[0]},{range[1]})") + + def nxl(self, column: str, range: Tuple[int, int]): + return self.filter(column, Filters.NXL, f"({range[0]},{range[1]})") + + def nxr(self, column: str, range: Tuple[int, int]): + return self.filter(column, Filters.NXR, f"({range[0]},{range[1]})") + + def adj(self, column: str, range: Tuple[int, int]): + return self.filter(column, Filters.ADJ, f"({range[0]},{range[1]})") + + def match(self, query: Dict[str, Any]): + updated_query = None + for key in query: + value = query.get(key, "") + updated_query = self.eq(key, value) + return updated_query + + +class BaseSelectRequestBuilder(BaseFilterRequestBuilder): + def __init__(self, session: Union[AsyncClient, SyncClient]): + BaseFilterRequestBuilder.__init__(self, session) + + def order(self, column: str, *, desc=False, nullsfirst=False): + self.session.params[ + "order" + ] = f"{column}{'.desc' if desc else ''}{'.nullsfirst' if nullsfirst else ''}" + return self + + def limit(self, size: int, *, start=0): + self.session.headers["Range-Unit"] = "items" + self.session.headers["Range"] = f"{start}-{start + size - 1}" + return self + + def range(self, start: int, end: int): + self.session.headers["Range-Unit"] = "items" + self.session.headers["Range"] = f"{start}-{end - 1}" + return self + + def single(self): + self.session.headers["Accept"] = "application/vnd.pgrst.object+json" + return self diff --git a/postgrest_py/config.py b/postgrest_py/config.py deleted file mode 100644 index a2bb45be..00000000 --- a/postgrest_py/config.py +++ /dev/null @@ -1,6 +0,0 @@ -from typing import Dict - -DEFAULT_POSTGREST_CLIENT_HEADERS: Dict[str, str] = { - "Accept": "application/json", - "Content-Type": "application/json", -} diff --git a/postgrest_py/constants.py b/postgrest_py/constants.py new file mode 100644 index 00000000..93c294e5 --- /dev/null +++ b/postgrest_py/constants.py @@ -0,0 +1,42 @@ +from enum import Enum + + +class CountMethod(str, Enum): + exact = "exact" + planned = "planned" + estimated = "estimated" + + +class Filters(str, Enum): + NOT = "not" + EQ = "eq" + NEQ = "neq" + GT = "gt" + GTE = "gte" + LT = "lt" + LTE = "lte" + IS = "is" + LIKE = "like" + ILIKE = "ilike" + FTS = "fts" + PLFTS = "plfts" + PHFTS = "phfts" + WFTS = "wfts" + IN = "in" + CS = "cs" + CD = "cd" + OV = "ov" + SL = "sl" + SR = "sr" + NXL = "nxl" + NXR = "nxr" + ADJ = "adj" + + +class RequestMethod(str, Enum): + GET = "GET" + POST = "POST" + PATCH = "PATCH" + PUT = "PUT" + DELETE = "DELETE" + HEAD = "HEAD" diff --git a/tests/_async/test_request_builder.py b/tests/_async/test_request_builder.py index 2ea19dbd..385a6337 100644 --- a/tests/_async/test_request_builder.py +++ b/tests/_async/test_request_builder.py @@ -1,6 +1,7 @@ import pytest from postgrest_py import AsyncRequestBuilder +from postgrest_py.constants import CountMethod from postgrest_py.utils import AsyncClient @@ -24,7 +25,7 @@ def test_select(self, request_builder: AsyncRequestBuilder): assert builder.json == {} def test_select_with_count(self, request_builder: AsyncRequestBuilder): - builder = request_builder.select(count="exact") + builder = request_builder.select(count=CountMethod.exact) assert builder.session.params.get("select") is None assert builder.session.headers["prefer"] == "count=exact" @@ -43,7 +44,7 @@ def test_insert(self, request_builder: AsyncRequestBuilder): assert builder.json == {"key1": "val1"} def test_insert_with_count(self, request_builder: AsyncRequestBuilder): - builder = request_builder.insert({"key1": "val1"}, count="exact") + builder = request_builder.insert({"key1": "val1"}, count=CountMethod.exact) assert builder.session.headers.get_list("prefer", True) == [ "return=representation", @@ -52,7 +53,7 @@ def test_insert_with_count(self, request_builder: AsyncRequestBuilder): assert builder.http_method == "POST" assert builder.json == {"key1": "val1"} - def test_upsert(self, request_builder: AsyncRequestBuilder): + def test_insert_with_upsert(self, request_builder: AsyncRequestBuilder): builder = request_builder.insert({"key1": "val1"}, upsert=True) assert builder.session.headers.get_list("prefer", True) == [ @@ -62,6 +63,16 @@ def test_upsert(self, request_builder: AsyncRequestBuilder): assert builder.http_method == "POST" assert builder.json == {"key1": "val1"} + def test_upsert(self, request_builder: AsyncRequestBuilder): + builder = request_builder.upsert({"key1": "val1"}) + + assert builder.session.headers.get_list("prefer", True) == [ + "return=representation", + "resolution=merge-duplicates", + ] + assert builder.http_method == "POST" + assert builder.json == {"key1": "val1"} + class TestUpdate: def test_update(self, request_builder: AsyncRequestBuilder): @@ -74,7 +85,7 @@ def test_update(self, request_builder: AsyncRequestBuilder): assert builder.json == {"key1": "val1"} def test_update_with_count(self, request_builder: AsyncRequestBuilder): - builder = request_builder.update({"key1": "val1"}, count="exact") + builder = request_builder.update({"key1": "val1"}, count=CountMethod.exact) assert builder.session.headers.get_list("prefer", True) == [ "return=representation", @@ -95,7 +106,7 @@ def test_delete(self, request_builder: AsyncRequestBuilder): assert builder.json == {} def test_delete_with_count(self, request_builder: AsyncRequestBuilder): - builder = request_builder.delete(count="exact") + builder = request_builder.delete(count=CountMethod.exact) assert builder.session.headers.get_list("prefer", True) == [ "return=representation", diff --git a/tests/_sync/test_request_builder.py b/tests/_sync/test_request_builder.py index 61a53375..dac03efe 100644 --- a/tests/_sync/test_request_builder.py +++ b/tests/_sync/test_request_builder.py @@ -1,6 +1,7 @@ import pytest from postgrest_py import SyncRequestBuilder +from postgrest_py.constants import CountMethod from postgrest_py.utils import SyncClient @@ -24,7 +25,7 @@ def test_select(self, request_builder: SyncRequestBuilder): assert builder.json == {} def test_select_with_count(self, request_builder: SyncRequestBuilder): - builder = request_builder.select(count="exact") + builder = request_builder.select(count=CountMethod.exact) assert builder.session.params.get("select") is None assert builder.session.headers["prefer"] == "count=exact" @@ -43,7 +44,7 @@ def test_insert(self, request_builder: SyncRequestBuilder): assert builder.json == {"key1": "val1"} def test_insert_with_count(self, request_builder: SyncRequestBuilder): - builder = request_builder.insert({"key1": "val1"}, count="exact") + builder = request_builder.insert({"key1": "val1"}, count=CountMethod.exact) assert builder.session.headers.get_list("prefer", True) == [ "return=representation", @@ -52,7 +53,7 @@ def test_insert_with_count(self, request_builder: SyncRequestBuilder): assert builder.http_method == "POST" assert builder.json == {"key1": "val1"} - def test_upsert(self, request_builder: SyncRequestBuilder): + def test_insert_with_upsert(self, request_builder: SyncRequestBuilder): builder = request_builder.insert({"key1": "val1"}, upsert=True) assert builder.session.headers.get_list("prefer", True) == [ @@ -62,6 +63,16 @@ def test_upsert(self, request_builder: SyncRequestBuilder): assert builder.http_method == "POST" assert builder.json == {"key1": "val1"} + def test_upsert(self, request_builder: SyncRequestBuilder): + builder = request_builder.upsert({"key1": "val1"}) + + assert builder.session.headers.get_list("prefer", True) == [ + "return=representation", + "resolution=merge-duplicates", + ] + assert builder.http_method == "POST" + assert builder.json == {"key1": "val1"} + class TestUpdate: def test_update(self, request_builder: SyncRequestBuilder): @@ -74,7 +85,7 @@ def test_update(self, request_builder: SyncRequestBuilder): assert builder.json == {"key1": "val1"} def test_update_with_count(self, request_builder: SyncRequestBuilder): - builder = request_builder.update({"key1": "val1"}, count="exact") + builder = request_builder.update({"key1": "val1"}, count=CountMethod.exact) assert builder.session.headers.get_list("prefer", True) == [ "return=representation", @@ -95,7 +106,7 @@ def test_delete(self, request_builder: SyncRequestBuilder): assert builder.json == {} def test_delete_with_count(self, request_builder: SyncRequestBuilder): - builder = request_builder.delete(count="exact") + builder = request_builder.delete(count=CountMethod.exact) assert builder.session.headers.get_list("prefer", True) == [ "return=representation",