DataFrame.observe(observation: Union[Observation, str], *exprs: pyspark.sql.column.Column) → DataFrame

Define (named) metrics to observe on the DataFrame. This method returns an ‘observed’ DataFrame that returns the same result as the input, with the following guarantees:

  • It will compute the defined aggregates (metrics) on all the data that is flowing through the Dataset at that point.

  • It will report the value of the defined aggregate columns as soon as we reach a completion point. A completion point is either the end of a query (batch mode) or the end of a streaming epoch. The value of the aggregates only reflects the data processed since the previous completion point.
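
For instance, because the metrics are computed on the data flowing through the Dataset at the point where observe is attached, an observation placed before a filter still counts every input row. A minimal batch sketch (the DataFrame and observation name here are illustrative):

>>> from pyspark.sql import Observation
>>> from pyspark.sql.functions import count, lit
>>> obs = Observation("before_filter")
>>> sample_df = spark.createDataFrame([(1,), (2,), (3,)], ["v"])
>>> sample_df.observe(obs, count(lit(1)).alias("n")).filter("v > 2").count()
1
>>> obs.get  # reflects all 3 rows seen at the observation point, not the 1 filtered row
{'n': 3}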

The metrics columns must either contain a literal (e.g. lit(42)), or should contain one or more aggregate functions (e.g. sum(a) or sum(a + b) + avg(c) - lit(1)). Expressions that contain references to the input Dataset’s columns must always be wrapped in an aggregate function.
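
To make the rule concrete, here is a minimal sketch, assuming a DataFrame df with numeric columns a, b and c (the exact error raised for the invalid case varies by Spark version):

>>> from pyspark.sql import Observation
>>> from pyspark.sql.functions import avg, col, lit, sum
>>> valid_df = df.observe(Observation("valid"),
...     (sum(col("a") + col("b")) + avg(col("c")) - lit(1)).alias("m"))  # aggregates: accepted
>>> df.observe(Observation("invalid"), col("a"))  # fails: bare column reference, not aggregated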

A user can observe these metrics by adding Python’s StreamingQueryListener, Scala/Java’s org.apache.spark.sql.streaming.StreamingQueryListener or Scala/Java’s org.apache.spark.sql.util.QueryExecutionListener to the spark session.

Parameters
    observation : Observation or str
        str to specify the name, or an Observation instance to obtain the metric.

        Changed in version 3.4.0: Added support for str in this parameter.

    exprs : Column
        column expressions (Column).

Returns
    DataFrame
        the observed DataFrame.


Notes

When observation is Observation, this method only supports batch queries. When observation is a string, this method works for both batch and streaming queries. Continuous execution is not yet supported.


Examples

When observation is Observation, only batch queries work, as below.

>>> from pyspark.sql.functions import col, count, lit, max
>>> from pyspark.sql import Observation
>>> # Sample two-row DataFrame, assumed here so the example is self-contained
>>> df = spark.createDataFrame([["Alice", 2], ["Bob", 5]], ["name", "age"])
>>> observation = Observation("my metrics")
>>> observed_df = df.observe(observation, count(lit(1)).alias("count"), max(col("age")))
>>> observed_df.count()
2
>>> observation.get
{'count': 2, 'max(age)': 5}

When observation is a string, streaming queries also work as below.

>>> from pyspark.sql.streaming import StreamingQueryListener
>>> class MyErrorListener(StreamingQueryListener):
...    def onQueryStarted(self, event):
...        pass
...    def onQueryProgress(self, event):
...        row = event.progress.observedMetrics.get("my_event")
...        # Trigger if the number of errors exceeds 5 percent
...        num_rows = row.rc
...        num_error_rows = row.erc
...        ratio = num_error_rows / num_rows
...        if ratio > 0.05:
...            # Trigger alert
...            pass
...    def onQueryTerminated(self, event):
...        pass
>>> spark.streams.addListener(MyErrorListener())
>>> # Observe row count (rc) and error row count (erc) in the streaming Dataset
... observed_ds = df.observe(
...     "my_event",
...     count(lit(1)).alias("rc"),
...     count(col("error")).alias("erc"))
>>> observed_ds.writeStream.format("console").start()
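
If the listener should ever be detached, keep a reference to the instance, since removeListener must be passed the same object that was registered. A small follow-up sketch under that assumption:

>>> my_listener = MyErrorListener()
>>> spark.streams.addListener(my_listener)
>>> query = observed_ds.writeStream.format("console").start()
>>> # ... later, shut the stream down and detach the listener
>>> query.stop()
>>> spark.streams.removeListener(my_listener)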