# Source code for deephaven_enterprise.pivot.client.pivot_client

"""Pivot table client implementation for Deephaven Enterprise.

This module provides classes and functionality for creating and managing pivot tables over gRPC.

The :class:`PivotClient` creates new PivotTables from your `pydeephaven.session`.  Once you have a
:class:`PivotTable`, you can apply changes or call :meth:`PivotTable.view` to create a
:class:`PivotTableSubscription`.  The subscription provides a :class:`PivotSnapshot` of your
viewport on each change, via a :class:`PivotListener` passed to the :meth:`PivotTableSubscription.add_listener` method.
"""

from __future__ import annotations

import threading
from abc import ABC, abstractmethod
from typing import Optional, Union

import pyarrow
import pydeephaven.experimental.plugin_client
import pydeephaven.ticket
from deephaven_enterprise.pivot.flatbuf import PivotBarrageMessageType
from deephaven_enterprise.pivot.flatbuf.BarrageMessageWrapper import (
    BarrageMessageWrapper,
)
from deephaven_enterprise.pivot.flatbuf.PivotUpdateMetadata import PivotUpdateMetadata
from deephaven_enterprise.proto.pivot.pivottable_pb2 import PivotTableDescriptorMessage
from pyarrow.flight import FlightCallOptions
from pydeephaven import DHError, Session
from pydeephaven.experimental.plugin_client import Fetchable, PluginClient

from ._barrage import (
    BARRAGE_MAGIC,
    encode_barrage_subscription_request,
)
from ._plugin_rpc import (
    RequestTracker,
    apply_to_pivot,
    create_pivot,
    empty_expansions,
    view,
)
from ._rowset import decode_rowset


class PivotSnapshot:
    """A snapshot of pivot table data including row/column information and values.

    This class encapsulates all the data for a specific state of a pivot table,
    including row and column metadata, values, and viewport information.
    """

    # Raw RecordBatch that backs this snapshot.
    _batch: pyarrow.RecordBatch
    # Schema indices of the special pivot columns.
    _row_depth: int
    _row_expanded: int
    _row_keys: list[int]
    _row_totals: list[int]
    _column_depth: int
    _column_expanded: int
    _column_keys: list[int]
    _column_totals: list[int]
    _values: list[int]
    _grand_totals: list[int]
    # Visible/total extents derived from the batch and the Barrage metadata.
    _visible_row_count: int
    _visible_col_count: int
    _total_row_count: int
    _total_col_count: int
    _row_viewport: Union[tuple[()], tuple[int, int]]
    _columns_viewport: Union[tuple[()], tuple[int, int]]

    def __init__(self, batch: pyarrow.RecordBatch, md_bytes: bytes):
        """Initialize a new PivotSnapshot.

        Args:
            batch: PyArrow RecordBatch containing the pivot table data.
            md_bytes: Metadata bytes containing viewport information.

        Raises:
            DHError: If required pivot table schema columns are missing or invalid,
                or if the Barrage metadata does not contain PivotUpdateMetadata.
        """
        self._batch = batch
        schema = batch.schema
        self._row_keys = []
        self._row_totals = []
        self._column_keys = []
        self._column_totals = []
        self._values = []
        self._grand_totals = []
        bmw = BarrageMessageWrapper.GetRootAsBarrageMessageWrapper(md_bytes)
        if (
            bmw.Magic() == BARRAGE_MAGIC
            and bmw.MsgType()
            == PivotBarrageMessageType.PivotBarrageMessageType.PivotUpdateMetadata
        ):
            pum = PivotUpdateMetadata.GetRootAs(bmw.MsgPayloadAsNumpy().tobytes())
            self._total_row_count = pum.RowsSize()
            self._total_col_count = pum.ColumnsSize()
            self._row_viewport = decode_rowset(
                pum.EffectiveRowViewportAsNumpy().tobytes()
            )
            self._columns_viewport = decode_rowset(
                pum.EffectiveColumnViewportAsNumpy().tobytes()
            )
        else:
            # Without valid PivotUpdateMetadata the total counts and viewports are
            # unknown; fail fast here instead of raising AttributeError on first
            # access to total_row_count / total_col_count.
            raise DHError(
                "Invalid pivot snapshot metadata: expected a PivotUpdateMetadata "
                "Barrage message wrapper."
            )

        row_depth = None
        row_expanded = None
        column_depth = None
        column_expanded = None
        for cc in range(0, len(schema)):
            # metadata may be None for columns without any metadata; treat that
            # the same as "no pivot markers" rather than raising TypeError.
            column_metadata = schema[cc].metadata or {}
            if b"deephaven:pivotTable.isRowDepthColumn" in column_metadata:
                row_depth = cc
            elif b"deephaven:pivotTable.isRowExpandedColumn" in column_metadata:
                row_expanded = cc
            elif b"deephaven:pivotTable.isRowGroupByColumn" in column_metadata:
                self._row_keys.append(cc)
            elif b"deephaven:pivotTable.isRowValueColumn" in column_metadata:
                self._row_totals.append(cc)
            elif b"deephaven:pivotTable.isColumnDepthColumn" in column_metadata:
                column_depth = cc
            elif b"deephaven:pivotTable.isColumnExpandedColumn" in column_metadata:
                column_expanded = cc
            elif b"deephaven:pivotTable.isColumnGroupByColumn" in column_metadata:
                self._column_keys.append(cc)
            elif b"deephaven:pivotTable.isColumnValueColumn" in column_metadata:
                self._column_totals.append(cc)
            elif b"deephaven:pivotTable.isValueColumn" in column_metadata:
                self._values.append(cc)
            elif b"deephaven:pivotTable.isGrandTotalValueColumn" in column_metadata:
                self._grand_totals.append(cc)
            else:
                print("Unknown column: ", cc, schema[cc])
        if (
            row_depth is None
            or column_depth is None
            or row_expanded is None
            or column_expanded is None
        ):
            raise DHError(
                "Invalid pivot table, row depth, column depth, row expanded, column expanded must be set."
            )
        # The row totals names, column totals names, grand totals names, and
        # values names must all match.
        check_names = self.value_names
        gtnames = list(
            map(lambda idx: self._batch.schema[idx].name, self._grand_totals)
        )
        if gtnames != check_names:
            raise DHError(
                f"Grand totals names do not match value names: {gtnames} != {check_names}"
            )
        rtnames = list(map(lambda idx: self._batch.schema[idx].name, self._row_totals))
        if rtnames != check_names:
            raise DHError(
                f"Row total names do not match value names: {rtnames} != {check_names}"
            )
        ctnames = list(
            map(lambda idx: self._batch.schema[idx].name, self._column_totals)
        )
        if ctnames != check_names:
            raise DHError(
                f"Column total names do not match value names: {ctnames} != {check_names}"
            )
        self._row_depth = row_depth
        self._row_expanded = row_expanded
        self._column_depth = column_depth
        self._column_expanded = column_expanded
        # Each special column holds one list element; its length is the count of
        # visible rows/columns in this snapshot.
        self._visible_row_count = len(batch[self._row_depth][0])
        self._visible_col_count = len(batch[self._column_depth][0])

    @property
    def total_row_count(self) -> int:
        """Get the total number of rows in the pivot table given the current expansions.

        The number of visible rows may be less than this value, based on the row viewport.
        If the rows were fully expanded, then there may be more rows.

        Returns:
            int: Total number of rows given the current expansions.
        """
        return self._total_row_count

    @property
    def total_col_count(self) -> int:
        """Get the total number of columns in the pivot table given the current expansions.

        The number of visible columns may be less than this value, based on the column viewport.
        If the columns were fully expanded, then there may be more columns.

        Returns:
            int: Total number of columns given the current expansions.
        """
        return self._total_col_count

    @property
    def num_columns(self) -> int:
        """Get the number of visible columns in the pivot table.

        Returns:
            int: Number of visible columns.
        """
        return self._visible_col_count

    @property
    def num_rows(self) -> int:
        """Get the number of visible rows in the pivot table.

        Returns:
            int: Number of visible rows.
        """
        return self._visible_row_count

    @property
    def num_values(self) -> int:
        """Get the number of value columns in the pivot table.

        The row, column, and grand totals have the same counts and names.

        Returns:
            int: Number of value columns.
        """
        return len(self._values)

    @property
    def row_key_names(self) -> list[str]:
        """Get the names of the row key columns.

        Returns:
            list[str]: List of row key column names.
        """
        return list(map(lambda idx: self._batch.schema[idx].name, self._row_keys))

    @property
    def column_key_names(self) -> list[str]:
        """Get the names of the column key columns.

        Returns:
            list[str]: List of column key column names.
        """
        return list(map(lambda idx: self._batch.schema[idx].name, self._column_keys))

    def _row_key_indices(self) -> list[int]:
        """Get the indices of the row key columns.

        Returns:
            list[int]: List of row key column indices.
        """
        return self._row_keys

    def _column_key_indices(self) -> list[int]:
        """Get the indices of the column key columns.

        Returns:
            list[int]: List of column key column indices.
        """
        return self._column_keys

    @property
    def value_names(self) -> list[str]:
        """Get the names of the value columns.

        The row, column, and grand totals have the same counts and names.

        Returns:
            list[str]: List of value column names.
        """
        return list(map(lambda idx: self._batch.schema[idx].name, self._values))

    def _value_indices(self) -> list[int]:
        """Get the indices of the value columns.

        Returns:
            list[int]: List of value column indices.
        """
        return self._values

    def _row_totals_indices(self) -> list[int]:
        """Get the indices of the row total columns.

        Returns:
            list[int]: List of row total column indices.
        """
        return self._row_totals

    def _column_totals_indices(self) -> list[int]:
        """Get the indices of the column total columns.

        Returns:
            list[int]: List of column total column indices.
        """
        return self._column_totals

    def _grand_totals_indices(self) -> list[int]:
        """Get the indices of the grand total columns.

        Returns:
            list[int]: List of grand total column indices.
        """
        return self._grand_totals

    def grand_totals(self) -> pyarrow.Table:
        """Get the grand totals as a PyArrow Table.

        Returns:
            pyarrow.Table: Table containing the grand total values.
        """
        gtvals = list(
            map(lambda gti: self._batch[gti][0], self._grand_totals_indices())
        )
        return pyarrow.Table.from_arrays(gtvals, names=self.value_names)

    def values(self) -> pyarrow.Table:
        """Get the field of values as a PyArrow Table.

        Each value is stored in column major order and forms a grid of
        :meth:`num_rows` by :meth:`num_columns` values.

        Returns:
            pyarrow.Table: Table containing the values.
        """
        values_field = list(map(lambda vi: self._batch[vi][0], self._value_indices()))
        return pyarrow.Table.from_arrays(values_field, names=self.value_names)

    def row_expanded(self) -> pyarrow.Array:
        """Get the row expanded states as a PyArrow Array.

        Returns:
            pyarrow.Array: Array of row expanded states.
        """
        return pyarrow.array(self._batch[self._row_expanded][0])

    def row_depths(self) -> pyarrow.Array:
        """Get the row depths as a PyArrow Array.

        Returns:
            pyarrow.Array: Array of row depths.
        """
        return pyarrow.array(self._batch[self._row_depth][0])

    def row_keys(self) -> pyarrow.Table:
        """Get the row keys as a PyArrow Table.

        Returns:
            pyarrow.Table: Table containing the row keys.
        """
        row_key_values = list(
            map(lambda rki: self._batch[rki][0], self._row_key_indices())
        )
        return pyarrow.Table.from_arrays(row_key_values, names=self.row_key_names)

    def col_expanded(self) -> pyarrow.Array:
        """Get the column expanded states as a PyArrow Array.

        Returns:
            pyarrow.Array: Array of column expanded states.
        """
        return pyarrow.array(self._batch[self._column_expanded][0])

    def col_depths(self) -> pyarrow.Array:
        """Get the column depths as a PyArrow Array.

        Returns:
            pyarrow.Array: Array of column depths.
        """
        return pyarrow.array(self._batch[self._column_depth][0])

    def column_keys(self) -> pyarrow.Table:
        """Get the column keys as a PyArrow Table.

        Returns:
            pyarrow.Table: Table containing the column keys.
        """
        column_key_values = list(
            map(lambda cki: self._batch[cki][0], self._column_key_indices())
        )
        return pyarrow.Table.from_arrays(column_key_values, names=self.column_key_names)

    def row_totals(self) -> pyarrow.Table:
        """Get the row totals as a PyArrow Table.

        Returns:
            pyarrow.Table: Table containing the row totals.
        """
        row_totals_values = list(
            map(lambda rti: self._batch[rti][0], self._row_totals_indices())
        )
        return pyarrow.Table.from_arrays(row_totals_values, names=self.value_names)

    def column_totals(self) -> pyarrow.Table:
        """Get the column totals as a PyArrow Table.

        Returns:
            pyarrow.Table: Table containing the column totals.
        """
        column_totals_values = list(
            map(lambda cti: self._batch[cti][0], self._column_totals_indices())
        )
        return pyarrow.Table.from_arrays(column_totals_values, names=self.value_names)

    def batch(self) -> pyarrow.RecordBatch:
        """Get the raw PyArrow RecordBatch.

        Returns:
            pyarrow.RecordBatch: The raw RecordBatch containing the data.
        """
        return self._batch
def _descriptor_from_view(ptv):
    """Creates a FlightDescriptor from a pivot table view.

    Args:
        ptv: Pivot table view to create descriptor from.

    Returns:
        pyarrow.flight.FlightDescriptor: Flight descriptor for the export path.

    Raises:
        DHError: If the ticket type is not an export ticket.
    """
    # Export tickets are tagged with a leading ASCII 'e'.
    if ptv.ticket.bytes[0] != "e".encode("ascii")[0]:
        # Format the byte value explicitly; concatenating an int to a str raises
        # TypeError and would mask the real error.
        raise DHError(f"Expected export ticket, but type was {ptv.ticket.bytes[0]}")
    # NOTE(review): only bytes 1..3 of the ticket are decoded as the export id —
    # confirm against the ticket wire layout.
    return pyarrow.flight.FlightDescriptor.for_path(
        "export", str(int.from_bytes(ptv.ticket.bytes[1:4], "little", signed=True))
    )
class PivotListener(ABC):
    """Interface for receiving pivot table snapshot notifications.

    Subclass this and implement both callbacks to be notified of each new pivot
    table snapshot. Register a listener via
    :meth:`PivotTableSubscription.add_listener`; deregister it via
    :meth:`PivotTableSubscription.remove_listener`.
    """

    @abstractmethod
    def on_snapshot(self, snap: PivotSnapshot):
        """Invoked each time a fresh pivot table snapshot becomes available.

        Args:
            snap: The newly produced PivotSnapshot.
        """

    @abstractmethod
    def on_complete(self, error: Optional[Exception]):
        """Invoked once when this listener will receive no further updates from
        this subscription.

        Args:
            error: The failure that terminated the subscription, or None when it
                ended normally.
        """
class PivotTableSubscription:
    """Manages a subscription to pivot table snapshots.

    Handles the communication channel for receiving pivot table updates and
    manages listeners for update notifications.
    """

    _server_object: pydeephaven.ticket.ServerObject
    # Listeners registered via add_listener. Initialized per-instance in
    # __init__; a class-level default list would be shared by every
    # subscription (classic mutable-class-attribute bug).
    _listeners: list[PivotListener]
    _snap: Optional[PivotSnapshot]
    _reader: pyarrow.flight.FlightStreamReader
    _writer: pyarrow.flight.FlightStreamWriter
    _session: pydeephaven.session.Session
    _completed: bool

    class _ReaderThread(threading.Thread):
        """Daemon thread that pulls snapshot chunks off the Flight exchange."""

        def __init__(
            self,
            sub: PivotTableSubscription,
        ):
            threading.Thread.__init__(self)
            self.stop = False
            self.daemon = True
            self.sub = sub

        def run(self):
            while not self.stop:
                try:
                    chunk = self.sub._reader.read_chunk()
                except StopIteration:
                    # The stream ended; complete without error.
                    break
                except Exception as e:
                    if self.stop:
                        # We were asked to stop, so just end naturally.
                        break
                    # Any other error is reported to the subscribers.
                    self.sub._complete(e)
                    self.sub._reader.cancel()
                    return
                pybytes = chunk.app_metadata.to_pybytes()
                snap = PivotSnapshot(chunk.data, pybytes)
                if not self.sub._set_snap(snap):
                    return
            self.sub._complete(None)

    def __init__(
        self,
        session: pydeephaven.Session,
        server_object: pydeephaven.ticket.ServerObject,
        row_expansions: pydeephaven.Table,
        col_expansions: pydeephaven.Table,
        initial_rows: Optional[tuple[int, int]] = None,
        initial_cols: Optional[tuple[int, int]] = None,
    ):
        """Initialize a new pivot table subscription.

        Args:
            session: Deephaven session.
            server_object: Server-side pivot table object.
            row_expansions: Table containing row expansion states.
            col_expansions: Table containing column expansion states.
            initial_rows: Optional tuple of (start, end) for initial row viewport.
            initial_cols: Optional tuple of (start, end) for initial column viewport.
        """
        self._server_object = server_object
        self._row_expansions = row_expansions
        self._col_expansions = col_expansions
        self._session = session
        # Per-instance mutable state; must not live on the class.
        self._listeners = []
        self._snap = None
        self._completed = False
        descriptor = _descriptor_from_view(server_object)
        fc: pyarrow.flight.FlightClient = session.flight_service._flight_client
        (writer, reader) = fc.do_exchange(
            descriptor, FlightCallOptions(headers=session.grpc_metadata)
        )
        self._reader = reader
        self._writer = writer
        # (-1, -1) requests the full range on the server.
        if initial_rows is None:
            initial_rows = (-1, -1)
        if initial_cols is None:
            initial_cols = (-1, -1)
        writer.write_metadata(
            encode_barrage_subscription_request(
                server_object.ticket.bytes, initial_rows, initial_cols
            )
        )
        self._rt = PivotTableSubscription._ReaderThread(self)
        self._rt.start()

    def add_listener(self, listener: PivotListener):
        """Register a new pivot table snapshot listener.

        If the subscription has already been completed, then an exception is raised.

        The listener is called once with the most recent snapshot state (if available).
        If a listener raises an Exception while processing the initial PivotSnapshot in
        the :meth:`PivotListener.on_snapshot` method, then the listener is not added,
        and the Exception is re-raised.

        After the initial PivotSnapshot, if a listener's on_complete method raises an
        Exception, the exception is ignored. If a listener's on_snapshot method raises
        an Exception, then the on_complete method is called and the subscription is
        closed.

        Args:
            listener: PivotListener instance to register.
        """
        if self._completed:
            raise DHError("Subscription has already been completed")
        if self._snap is not None:
            try:
                listener.on_snapshot(self._snap)
            except Exception as e:
                self.close()
                raise DHError("Listener failed to process initial PivotSnapshot") from e
        self._listeners.append(listener)

    def remove_listener(self, listener: PivotListener):
        """Remove a registered pivot table snapshot listener.

        Args:
            listener: PivotListener instance to remove.
        """
        self._listeners.remove(listener)

    def _set_snap(self, snap: PivotSnapshot):
        """Update the current snapshot and notify listeners.

        If the listeners are not called successfully, then the subscription is
        closed and False is returned.

        Args:
            snap: New PivotSnapshot instance.

        Returns:
            True if the listeners completed successfully, False otherwise.
        """
        self._snap = snap
        for ll in self._listeners:
            try:
                ll.on_snapshot(snap)
            except Exception as e:
                self._complete(DHError("Failed to process PivotSnapshot: " + str(e)))
                self.close()
                return False
        return True

    def _complete(self, error: Optional[Exception]):
        """Signal completion to all listeners."""
        self._snap = None
        if self._completed:
            raise DHError("Listener has already been completed")
        self._completed = True
        for ll in self._listeners:
            try:
                ll.on_complete(error)
            except Exception:
                # We can't do anything with a completion failure; that is how we
                # would notify of completion anyway.
                pass

    def close(self):
        """Close the subscription and receive no further updates."""
        self._rt.stop = True
        self._reader.cancel()
        # Identity comparison: don't join the reader thread from itself.
        if self._rt is not threading.current_thread():
            self._rt.join()

    def row_expansions(self) -> pydeephaven.Table:
        """Get the table of row expansion states.

        If the row expansion table is an input table, then any modifications are
        reflected in the snapshots for this subscription.

        Returns:
            pydeephaven.Table: Table containing row expansion states.
        """
        return self._row_expansions

    def col_expansions(self):
        """Get the table of column expansion states.

        If the column expansion table is an input table, then any modifications are
        reflected in the snapshots for this subscription.

        Returns:
            pydeephaven.Table: Table containing column expansion states.
        """
        return self._col_expansions

    def set_viewport(self, rows: tuple[int, int], cols: tuple[int, int]):
        """Update the viewport for rows and columns.

        Note that totals are always sent, so the row and columns viewports only
        account for the field of values. Additionally, the column viewport
        represents the number of column keys that are visible. When multiple values
        are requested, you may choose to display those values in distinct columns
        (as is done in the
        :meth:`deephaven_enterprise.pivot.client.formatter.format_grid` method),
        and your column viewport must be adjusted accordingly.

        Args:
            rows: Tuple of (start, end) for row viewport.
            cols: Tuple of (start, end) for column viewport.
        """
        self._writer.write_metadata(
            encode_barrage_subscription_request(
                self._server_object.ticket.bytes, rows, cols
            )
        )
class PivotTable:
    """Represents a pivot table.

    Provides interface for applying operations to pivot tables and creating
    subscriptions for snapshots.
    """

    _server_object: pydeephaven.ticket.ServerObject
    _pivot_client: "PivotClient"
    _descriptor: PivotTableDescriptorMessage
    # Names of the value (output) columns discovered from the snapshot schema.
    _outputs: list[str]
    _schema: pyarrow.Schema

    def __init__(
        self,
        server_object: pydeephaven.ticket.ServerObject,
        descriptor: PivotTableDescriptorMessage,
        pivot_client: "PivotClient",
    ):
        """Initialize a new pivot table.

        Args:
            server_object: Server-side pivot table object.
            descriptor: Descriptor message containing table information.
            pivot_client: Associated PivotClient instance.
        """
        self._server_object = server_object
        self._pivot_client = pivot_client
        self._descriptor = descriptor
        snapshot_schema = self._descriptor.snapshot_schema
        with pyarrow.BufferReader(snapshot_schema) as reader:
            stream = pyarrow.ipc.open_stream(reader)
            self._schema = stream.schema
        self._outputs = []
        for cc in range(0, len(self._schema)):
            # metadata may be None for columns without metadata.
            column_metadata = self._schema[cc].metadata or {}
            if b"deephaven:pivotTable.isValueColumn" in column_metadata:
                self._outputs.append(self._schema[cc].name)

    def outputs(self) -> list[str]:
        """Get the list of output column names.

        Returns:
            list[str]: List of output column names.
        """
        return self._outputs

    def refreshing(self) -> bool:
        """Check if this pivot table is refreshing.

        Returns:
            bool: True if the table is refreshing, False if static.
        """
        return not self._descriptor.is_static

    def apply(
        self, row_sorts: list[str], col_sorts: list[str], filters: list[str]
    ) -> "PivotTable":
        """Apply sorting and filtering operations to the pivot table.

        Args:
            row_sorts: List of row sort expressions.
            col_sorts: List of column sort expressions.
            filters: List of filter expressions.

        Returns:
            PivotTable: New pivot table with applied operations.
        """
        (pivot, descriptor) = apply_to_pivot(
            self._pivot_client.plugin_client,
            self._pivot_client.tracker,
            self._server_object,
            row_sorts,
            col_sorts,
            filters,
        )
        return PivotTable(
            pivot,
            descriptor,
            self._pivot_client,
        )

    def empty_expansions(self) -> tuple[Fetchable, Fetchable]:
        """Get empty expansion tables for rows and columns.

        This table can be used as a prototype for the row and expansion tables
        for the :meth:`view` method.

        Returns:
            tuple: (row_expansions, col_expansions) tables.
        """
        return empty_expansions(
            self._pivot_client.plugin_client,
            self._pivot_client.tracker,
            self._server_object,
        )

    def view(
        self,
        outputs: Optional[list[str]] = None,
        row_expansions: Optional[pydeephaven.Table] = None,
        col_expansions: Optional[pydeephaven.Table] = None,
        initial_rows: Optional[tuple[int, int]] = None,
        initial_cols: Optional[tuple[int, int]] = None,
    ) -> PivotTableSubscription:
        """Create a subscription view of the pivot table.

        Args:
            outputs: Optional list of output columns to include.
            row_expansions: Optional table of row expansion states.
            col_expansions: Optional table of column expansion states.
            initial_rows: Optional tuple of (start, end) for initial row viewport.
            initial_cols: Optional tuple of (start, end) for initial column viewport.

        Returns:
            PivotTableSubscription: New subscription to pivot table updates.
        """
        if row_expansions is None or col_expansions is None:
            # Get empty expansions to use as prototypes for the input tables.
            (row_exp, col_exp) = self.empty_expansions()
            row_exp_table: Union[pydeephaven.Table, PluginClient] = row_exp.fetch()
            if not isinstance(row_exp_table, pydeephaven.Table):
                raise TypeError(
                    "row_exp_table is not a table: " + str(type(row_exp_table))
                )
            if row_expansions is None:
                row_expansions = self._pivot_client.session.input_table(
                    init_table=row_exp_table.update_view("Action=(byte)0"),
                    key_cols=row_exp_table.schema.names,  # type: ignore[union-attr]
                )
            if col_expansions is None:
                col_exp_table: Union[pydeephaven.Table, PluginClient] = col_exp.fetch()
                if not isinstance(col_exp_table, pydeephaven.Table):
                    # Report the offending object's own type (the original
                    # message mistakenly reported type(row_exp_table)).
                    raise TypeError(
                        "col_exp_table is not a table: " + str(type(col_exp_table))
                    )
                col_expansions = self._pivot_client.session.input_table(
                    init_table=col_exp_table.update_view("Action=(byte)0"),
                    key_cols=col_exp_table.schema.names,  # type: ignore[union-attr]
                )
        server_view_object = view(
            self._pivot_client.plugin_client,
            self._pivot_client.tracker,
            self._server_object,
            row_expansions,
            col_expansions,
            self._outputs if outputs is None else outputs,
        )
        return PivotTableSubscription(
            self._pivot_client.session,
            server_view_object,
            row_expansions,
            col_expansions,
            initial_rows,
            initial_cols,
        )
class PivotClient:
    """Client interface for creating pivot tables in Deephaven Enterprise.

    Attributes:
        session: The Deephaven Session this client is attached to.
        plugin_client: Plugin client for pivot table operations.
    """

    # The Session we are attached to
    session: Session
    plugin_client: pydeephaven.experimental.plugin_client.PluginClient

    def __init__(self, session: Session, plugin_name=None, ticket=None):
        """Initialize a new PivotClient.

        Args:
            session: Deephaven Session to use for operations.
            plugin_name: Optional name of the plugin to use (scope variable).
            ticket: Optional ticket for the plugin.

        Raises:
            DHError: If neither plugin_name nor ticket is supplied.
            RuntimeError: If plugin_client initialization fails.
        """
        self.session = session
        if ticket is not None:
            psp = ticket
        elif plugin_name is None:
            raise DHError("plugin_name or ticket is required!")
        else:
            # Look up the plugin's exported scope object by the caller-supplied
            # name (previously hard-coded to "psp", ignoring plugin_name).
            psp = session.exportable_objects[plugin_name]
        self.plugin_client = session.plugin_client(psp)
        if self.plugin_client is None:
            raise RuntimeError("plugin_client is None.")
        self.tracker = RequestTracker()
        # The first response from the plugin must carry an empty payload.
        response_payload, response_refs = next(self.plugin_client.resp_stream)
        if len(response_payload) != 0:
            raise RuntimeError(
                "Non empty payload first response, payload=" + f"{response_payload!r}"
            )

    def create_pivot(self, row_by_keys, col_by_keys, aggs, source_table):
        """Create a new pivot table.

        Args:
            row_by_keys: Keys to group rows by.
            col_by_keys: Keys to group columns by.
            aggs: Aggregations to perform.
            source_table: Source table to pivot.

        Returns:
            PivotTable: A new pivot table instance.
        """
        (pivot, descriptor) = create_pivot(
            self.plugin_client,
            self.tracker,
            row_by_keys,
            col_by_keys,
            aggs,
            source_table,
        )
        return PivotTable(
            pivot,
            descriptor,
            self,
        )
def fetch(
    session: Session, server_object: pydeephaven.ticket.ServerObject
) -> PivotTable:
    """Fetch a PivotTable.

    The ServerObject must refer to a server-side Pivot table, and the session must
    correspond to the ServerObject. If the object is not a pivot table or the
    ticket is otherwise invalid, then a DHError or gRPC exception is raised.

    Args:
        session: Deephaven Session that corresponds to the ServerObject.
        server_object: Server object representing the PivotTable.

    Returns:
        PivotTable: A new PivotTable instance from the provided Session and ServerObject.
    """
    plugin = session.plugin_client(server_object)
    # The first response on the stream carries the serialized descriptor plus a
    # reference to the plugin ticket.
    payload, refs = next(plugin.resp_stream)
    descriptor = PivotTableDescriptorMessage()
    descriptor.ParseFromString(payload)
    client = PivotClient(session, ticket=refs[0])
    return PivotTable(server_object, descriptor, client)