pyspark.sql.DataFrame.observe

DataFrame.observe(observation: Union[Observation, str], *exprs: pyspark.sql.column.Column) → DataFrame

Define (named) metrics to observe on the DataFrame. This method returns an ‘observed’ DataFrame that returns the same result as the input, with the following guarantees:

  • It will compute the defined aggregates (metrics) on all the data that is flowing through the Dataset at that point (see the sketch after this list).

  • It will report the value of the defined aggregate columns as soon as we reach a completion point. A completion point is either the end of a query (batch mode) or the end of a streaming epoch. The value of the aggregates only reflects the data processed since the previous completion point.

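As a rough sketch of the first guarantee (the observation name and the use of spark.range below are illustrative, not part of this API's documentation): the metrics are computed where the observation sits in the plan, so transformations applied afterwards do not change them.

>>> from pyspark.sql import Observation
>>> from pyspark.sql.functions import count, lit
>>> ob = Observation("upstream rows")
>>> # Observe the row count before a filter is applied
>>> filtered = spark.range(10).observe(ob, count(lit(1)).alias("rows")).filter("id % 2 = 0")
>>> filtered.count()  # the action triggers the metrics
5
>>> ob.get  # all 10 rows flowed through the observed node, not just the 5 kept by the filter
{'rows': 10}
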
The metric columns must contain either a literal (e.g. lit(42)) or one or more aggregate functions (e.g. sum(a) or sum(a + b) + avg(c) - lit(1)). Expressions that contain references to the input Dataset’s columns must always be wrapped in an aggregate function.

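For instance (a rough sketch, assuming a DataFrame df with an age column such as the one used in the Examples section below; the observation name and aliases are illustrative):

>>> from pyspark.sql import Observation
>>> from pyspark.sql.functions import avg, col, lit, sum
>>> ok = df.observe(
...     Observation("valid metrics"),
...     lit(42).alias("answer"),                              # a literal is allowed
...     sum(col("age")).alias("age_sum"),                     # an aggregate is allowed
...     (sum(col("age")) + avg(col("age"))).alias("mixed"))   # combined aggregates are allowed
>>> # A bare column reference such as col("age") is not wrapped in an aggregate
>>> # function, so it would be rejected as a metric column.
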
A user can observe these metrics by adding Python’s StreamingQueryListener, Scala/Java’s org.apache.spark.sql.streaming.StreamingQueryListener, or Scala/Java’s org.apache.spark.sql.util.QueryExecutionListener to the Spark session.

Parameters
observation : Observation or str

str to specify the name, or an Observation instance to obtain the metric.

Added support for str in this parameter.

exprs : Column

column expressions (Column).

Returns
DataFrame

the observed DataFrame.

Notes

When observation is Observation, this method only supports batch queries. When observation is a string, this method works for both batch and streaming queries. Continuous execution is not yet supported.

Examples

When observation is Observation, only batch queries work, as shown below.

>>> from pyspark.sql.functions import col, count, lit, max
>>> from pyspark.sql import Observation
>>> df = spark.createDataFrame([("Alice", 2), ("Bob", 5)], ["name", "age"])  # example data
>>> observation = Observation("my metrics")
>>> observed_df = df.observe(observation, count(lit(1)).alias("count"), max(col("age")))
>>> observed_df.count()
2
>>> observation.get
{'count': 2, 'max(age)': 5}

When observation is a string, streaming queries also work, as shown below.

>>> from pyspark.sql.streaming import StreamingQueryListener
>>> class MyErrorListener(StreamingQueryListener):
...    def onQueryStarted(self, event):
...        pass
...
...    def onQueryProgress(self, event):
...        row = event.progress.observedMetrics.get("my_event")
...        # Trigger if the number of errors exceeds 5 percent
...        num_rows = row.rc
...        num_error_rows = row.erc
...        ratio = num_error_rows / num_rows
...        if ratio > 0.05:
...            # Trigger alert
...            pass
...
...    def onQueryTerminated(self, event):
...        pass
...
>>> spark.streams.addListener(MyErrorListener())
>>> # Observe row count (rc) and error row count (erc) in the streaming Dataset
... observed_ds = df.observe(
...     "my_event",
...     count(lit(1)).alias("rc"),
...     count(col("error")).alias("erc"))  
>>> observed_ds.writeStream.format("console").start()