
Commit 264ccfc

address comments and change example to prime numbers source
1 parent 3cbabc2 commit 264ccfc

File tree

2 files changed: +86 -30 lines changed


python/docs/source/user_guide/sql/python_data_source.rst

Lines changed: 86 additions & 25 deletions
@@ -535,7 +535,7 @@ Filter Pushdown in Python Data Sources
 
 Filter pushdown is an optimization technique that allows data sources to handle filters natively, reducing the amount of data that needs to be transferred and processed by Spark.
 
-The filter pushdown API is introduced in Spark 4.1, enabling DataSourceReader to selectively push down filters from the query to the source.
+The filter pushdown API enables DataSourceReader to selectively push down filters from the query to the source.
 
 You must turn on the configuration ``spark.sql.python.filterPushdown.enabled`` to enable filter pushdown.
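
The configuration mentioned above must be enabled before the data source is queried. A minimal sketch (not part of this commit), assuming an existing SparkSession:

    from pyspark.sql import SparkSession

    # Turn on filter pushdown for Python data sources before registering
    # and querying the source.
    spark = SparkSession.builder.getOrCreate()
    spark.conf.set("spark.sql.python.filterPushdown.enabled", "true")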

@@ -554,32 +554,93 @@ To enable filter pushdown in your Python Data Source, implement the ``pushFilter
 
 .. code-block:: python
 
-    from pyspark.sql.datasource import EqualTo, Filter, GreaterThan, LessThan
-
-    def pushFilters(self, filters: List[Filter]) -> Iterable[Filter]:
+    import math
+    from typing import Iterable, List
+    from pyspark.sql.datasource import (
+        DataSource,
+        DataSourceReader,
+        EqualTo,
+        Filter,
+        GreaterThan,
+        LessThan,
+        GreaterThanOrEqual,
+        LessThanOrEqual,
+    )
+
+
+    class PrimesDataSource(DataSource):
         """
-        Parameters
-        ----------
-        filters : list of Filter objects
-            The AND of the filters that Spark would like to push down
-
-        Returns
-        -------
-        iterable of Filter objects
-            Filters that could not be pushed down and still need to be
-            evaluated by Spark
+        A data source that enumerates prime numbers.
         """
-        # Process the filters and determine which ones can be handled by the data source
-        pushed = []
-        for filter in filters:
-            if isinstance(filter, (EqualTo, GreaterThan, LessThan)):
-                pushed.append(filter)
-            # Check for other supported filter types...
-            else:
-                yield filter  # Let Spark handle unsupported filters
-
-        # Store the pushed filters for use in partitions() and read() methods
-        self.pushed_filters = pushed
+
+        @classmethod
+        def name(cls):
+            return "primes"
+
+        def schema(self):
+            return "p int"
+
+        def reader(self, schema: str):
+            return PrimesDataSourceReader(schema, self.options)
+
+
+    class PrimesDataSourceReader(DataSourceReader):
+        def __init__(self, schema, options):
+            self.schema: str = schema
+            self.options = options
+            self.lower_bound = 2
+            self.upper_bound = math.inf
+
+        def pushFilters(self, filters: List[Filter]) -> Iterable[Filter]:
+            """
+            Parameters
+            ----------
+            filters : list of Filter objects
+                The AND of the filters that Spark would like to push down
+
+            Returns
+            -------
+            iterable of Filter objects
+                Filters that could not be pushed down and still need to be
+                evaluated by Spark
+            """
+            for filter in filters:
+                print(f"Got filter: {filter}")
+                if isinstance(filter, EqualTo):
+                    self.lower_bound = max(self.lower_bound, filter.value)
+                    self.upper_bound = min(self.upper_bound, filter.value)
+                elif isinstance(filter, GreaterThan):
+                    self.lower_bound = max(self.lower_bound, filter.value + 1)
+                elif isinstance(filter, LessThan):
+                    self.upper_bound = min(self.upper_bound, filter.value - 1)
+                elif isinstance(filter, GreaterThanOrEqual):
+                    self.lower_bound = max(self.lower_bound, filter.value)
+                elif isinstance(filter, LessThanOrEqual):
+                    self.upper_bound = min(self.upper_bound, filter.value)
+                else:
+                    yield filter  # Let Spark handle unsupported filters
+
+        def read(self, partition):
+            # Use the pushed filters to filter data during read
+            num = self.lower_bound
+            while num <= self.upper_bound:
+                if self._is_prime(num):
+                    yield [num]
+                num += 1
+
+        @staticmethod
+        def _is_prime(n: int) -> bool:
+            """Check if a number is prime."""
+            if n < 2:
+                return False
+            for i in range(2, int(n**0.5) + 1):
+                if n % i == 0:
+                    return False
+            return True
+
+    # Register the data source
+    spark.dataSource.register(PrimesDataSource)
+    spark.read.format("primes").load().filter("2000 <= p and p < 2050").show()
 
 **Notes**
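
As a worked illustration (not part of the commit) of what the new example does: for the query above, Spark can hand the reader a greater-than-or-equal filter for 2000 and a less-than filter for 2050, which the reader folds into numeric scan bounds. A standalone sketch of that arithmetic, with no Spark required:

    # Mimic how PrimesDataSourceReader turns the pushed filters from
    # .filter("2000 <= p and p < 2050") into scan bounds.
    lower_bound, upper_bound = 2, float("inf")
    lower_bound = max(lower_bound, 2000)      # GreaterThanOrEqual -> value
    upper_bound = min(upper_bound, 2050 - 1)  # LessThan -> value - 1

    def is_prime(n: int) -> bool:
        return n >= 2 and all(n % i for i in range(2, int(n**0.5) + 1))

    print([n for n in range(lower_bound, int(upper_bound) + 1) if is_prime(n)])
    # Expected: [2003, 2011, 2017, 2027, 2029, 2039]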

python/pyspark/sql/datasource.py

Lines changed: 0 additions & 5 deletions
@@ -539,11 +539,6 @@ def pushFilters(self, filters: List["Filter"]) -> Iterable["Filter"]:
         This method is allowed to modify `self`. The object must remain picklable.
         Modifications to `self` are visible to the `partitions()` and `read()` methods.
 
-        Notes
-        -----
-        Configuration `spark.sql.python.filterPushdown.enabled` must be set to `true`
-        to implement this method.
-
         Examples
         --------
         Example filters and the resulting arguments passed to pushFilters:
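
The context lines above also note that pushFilters may modify self as long as the reader stays picklable, since Spark serializes the reader before read() runs on executors. A hypothetical sketch (not from this commit) of that round trip, assuming the PrimesDataSourceReader class from the documentation example above is in scope:

    import pickle

    reader = PrimesDataSourceReader("p int", {})
    reader.lower_bound = 2000  # as if recorded by pushFilters()
    reader.upper_bound = 2049

    # Spark does the equivalent of this pickle round trip when shipping the
    # reader to executors; the pushed-down state must survive it.
    clone = pickle.loads(pickle.dumps(reader))
    assert clone.lower_bound == 2000 and clone.upper_bound == 2049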
