pyspark.sql.DataFrame.mapInArrow

DataFrame.mapInArrow(func: ArrowMapIterFunction, schema: Union[pyspark.sql.types.StructType, str]) → DataFrame

Maps an iterator of batches in the current DataFrame using a Python native function that takes and outputs PyArrow RecordBatches, and returns the result as a DataFrame.

The function should take an iterator of pyarrow.RecordBatch objects and return another iterator of pyarrow.RecordBatch objects. All columns are passed to the function together in each pyarrow.RecordBatch, and the returned pyarrow.RecordBatch objects are combined into a DataFrame. The maximum number of records in each input pyarrow.RecordBatch can be controlled with spark.sql.execution.arrow.maxRecordsPerBatch.

Parameters
func : function

a Python native function that takes an iterator of pyarrow.RecordBatch objects and outputs an iterator of pyarrow.RecordBatch objects.

schema : pyspark.sql.types.DataType or str

the return type of func in PySpark. The value can be either a pyspark.sql.types.DataType object or a DDL-formatted type string.

Notes

This API is unstable and intended for developers.

Examples

>>> import pyarrow  
>>> df = spark.createDataFrame([(1, 21), (2, 30)], ("id", "age"))
>>> def filter_func(iterator):
...     for batch in iterator:
...         pdf = batch.to_pandas()
...         yield pyarrow.RecordBatch.from_pandas(pdf[pdf.id == 1])
>>> df.mapInArrow(filter_func, df.schema).show()  
+---+---+
| id|age|
+---+---+
|  1| 21|
+---+---+