Source code for pavise.pandas

"""Pandas backend for type-parameterized DataFrame with Protocol-based schema validation."""

from typing import (
    Any,
    Generic,
    Literal,
    TypeVar,
    get_args,
    get_origin,
    get_type_hints,
)

import pandas as pd

from pavise._pandas.validation import INDEX_COLUMN_NAME, validate_dataframe
from pavise.types import NotRequiredColumn

# Names exported by ``from pavise.pandas import *``.
__all__ = ["DataFrame", "NotRequiredColumn"]

# Covariant type variable used as the schema parameter of ``DataFrame``.
SchemaT_co = TypeVar("SchemaT_co", covariant=True)


def _get_dtype_for_type(base_type: type) -> str | pd.api.extensions.ExtensionDtype:
    """
    Resolve the pandas dtype to use for a schema type.

    Args:
        base_type: Either a pandas ``ExtensionDtype`` subclass or a type to
            look up in ``TYPE_TO_DTYPE``.

    Returns:
        A new instance of ``base_type`` when it is an ``ExtensionDtype``
        subclass; otherwise the ``TYPE_TO_DTYPE`` entry for ``base_type``,
        or ``"object"`` when the mapping has no entry for it.
    """
    from pavise._pandas.validation import TYPE_TO_DTYPE

    extension_base = pd.api.extensions.ExtensionDtype
    is_extension_class = isinstance(base_type, type) and issubclass(base_type, extension_base)
    if not is_extension_class:
        # Anything that is not an extension dtype class goes through the mapping.
        return TYPE_TO_DTYPE.get(base_type, "object")

    # Extension dtype classes are instantiated with their default arguments.
    return base_type()


class DataFrame(pd.DataFrame, Generic[SchemaT_co]):
    """
    Type-parameterized DataFrame with runtime validation for pandas.

    Usage::

        # Static type checking only
        def process(df: DataFrame[UserSchema]) -> DataFrame[UserSchema]:
            return df

        # Runtime validation
        validated = DataFrame[UserSchema](raw_df)

    The type parameter is covariant, allowing structural subtyping.
    DataFrame[ChildSchema] is compatible with DataFrame[ParentSchema]
    when ChildSchema has all columns of ParentSchema.
    """

    # Schema bound by subscripting (``DataFrame[Schema]``); ``None`` on the
    # unparameterized class, in which case no validation is performed.
    _schema: type | None = None

    def __class_getitem__(cls, schema: type):
        """
        Create a new DataFrame class with schema validation.

        Args:
            schema: Schema class to bind to the returned class.

        Returns:
            A freshly created subclass of ``cls`` whose ``_schema`` is
            ``schema``. A new class is created on every subscription.
        """

        # Derive from ``cls`` (not the hard-coded base class) so subscripting
        # a user-defined subclass keeps that subclass's behaviour.
        class TypedDataFrame(cls):  # type: ignore[misc, valid-type]
            _schema = schema

        return TypedDataFrame

    def __new__(cls, data: Any = None, *args: Any, strict: bool = False, **kwargs: Any):
        """
        Create a new DataFrame instance.

        The constructor arguments are accepted so the signature matches
        ``__init__``; they are not used during allocation.
        """
        return super().__new__(cls)

    def __init__(self, data: Any = None, *args: Any, strict: bool = False, **kwargs: Any) -> None:
        """
        Initialize DataFrame with optional schema validation.

        Args:
            data: Data to create DataFrame from
            *args: Additional arguments passed to pd.DataFrame
            strict: Forwarded to ``validate_dataframe``; documented as raising
                an error on extra columns not in the schema when True
            **kwargs: Additional keyword arguments passed to pd.DataFrame

        Raises:
            Whatever ``validate_dataframe`` raises for this schema. The
            original documentation lists ValueError for a missing required
            column and TypeError for a column with the wrong type.
        """
        pd.DataFrame.__init__(self, data, *args, **kwargs)  # type: ignore[misc]
        # Validation only runs on classes produced by ``DataFrame[Schema]``.
        if self._schema is not None:
            validate_dataframe(self, self._schema, strict=strict)

    @staticmethod
    def _concrete_type(base_type: Any) -> Any:
        """Reduce an extracted schema type to one concrete type usable for a dtype lookup."""
        # Union types are represented as a tuple - use the first member.
        if isinstance(base_type, tuple):
            base_type = base_type[0]
        # Literal[...] - use the Python type of the first literal value.
        if get_origin(base_type) is Literal:
            literal_values = get_args(base_type)
            if literal_values:
                base_type = type(literal_values[0])
        return base_type

    @classmethod
    def make_empty(cls) -> "DataFrame[SchemaT_co]":
        """
        Create an empty DataFrame with columns from the schema.

        Every annotation on the schema other than the index entry becomes an
        empty column whose dtype comes from ``_get_dtype_for_type``. When the
        schema names an index, the empty index is given that name (or those
        level names for a MultiIndex) and a dtype derived from the declared
        index type.

        Returns:
            DataFrame: Empty DataFrame with correct column types. When no
            schema is bound, an empty DataFrame with no columns.
        """
        if cls._schema is None:
            return cls({})

        from pavise._pandas.validation import (
            _extract_index_name_type_and_validators,
            _extract_type_and_validators,
        )

        type_hints = get_type_hints(cls._schema, include_extras=True)
        columns = {}
        index_name = None
        index_base_type = None

        for col_name, col_type in type_hints.items():
            if col_name == INDEX_COLUMN_NAME:
                # The index entry describes the index, not a column.
                index_base_type, index_name, _validators, _is_optional = (
                    _extract_index_name_type_and_validators(col_type)
                )
                continue

            base_type, _validators, _is_optional, _is_not_required = (
                _extract_type_and_validators(col_type)
            )
            dtype = _get_dtype_for_type(cls._concrete_type(base_type))
            columns[col_name] = pd.Series([], dtype=dtype)

        # Build a plain pandas frame first and fix up its index; converting to
        # the typed class last means validation sees the final index.
        raw_df = pd.DataFrame(columns)

        if index_name is not None:
            if isinstance(index_name, tuple):
                # MultiIndex: one empty, correctly typed array per level.
                level_arrays = [
                    pd.array([], dtype=_get_dtype_for_type(level_type))
                    for level_type in get_args(index_base_type)
                ]
                raw_df.index = pd.MultiIndex.from_arrays(level_arrays, names=list(index_name))
            elif index_base_type is not None:
                # Single index: honour the declared index type as well as the
                # name, mirroring the MultiIndex branch. Without this the empty
                # frame keeps pandas' default index regardless of the schema.
                index_dtype = _get_dtype_for_type(cls._concrete_type(index_base_type))
                raw_df.index = pd.Index([], dtype=index_dtype, name=index_name)
            else:
                # No declared index type: only the name can be applied.
                raw_df.index.name = index_name

        return cls(raw_df)