"""Class-based api for pandas models."""
import copy
import inspect
import os
import re
import typing
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Mapping,
Optional,
Set,
Tuple,
Type,
TypeVar,
Union,
cast,
)
import pandas as pd
import pandas.util
from pandera.api.base.model import BaseModel
from pandera.api.checks import Check
from pandera.api.pandas.components import Column, Index, MultiIndex
from pandera.api.pandas.container import DataFrameSchema
from pandera.api.pandas.model_components import (
CHECK_KEY,
DATAFRAME_CHECK_KEY,
CheckInfo,
Field,
FieldCheckInfo,
FieldInfo,
)
from pandera.api.pandas.model_config import BaseConfig
from pandera.engines import PYDANTIC_V2
from pandera.errors import SchemaInitError
from pandera.strategies import pandas_strategies as st
from pandera.typing import INDEX_TYPES, SERIES_TYPES, AnnotationInfo
from pandera.typing.common import DataFrameBase
if PYDANTIC_V2:
from pydantic_core import core_schema
from pydantic import GetJsonSchemaHandler, GetCoreSchemaHandler
try:
from typing_extensions import get_type_hints
except ImportError: # pragma: no cover
from typing import get_type_hints # type: ignore
SchemaIndex = Union[Index, MultiIndex]
_CONFIG_KEY = "Config"
MODEL_CACHE: Dict[Type["DataFrameModel"], DataFrameSchema] = {}
GENERIC_SCHEMA_CACHE: Dict[
Tuple[Type["DataFrameModel"], Tuple[Type[Any], ...]],
Type["DataFrameModel"],
] = {}
F = TypeVar("F", bound=Callable)
TDataFrameModel = TypeVar("TDataFrameModel", bound="DataFrameModel")
def docstring_substitution(*args: Any, **kwargs: Any) -> Callable[[F], F]:
"""Typed wrapper around pandas.util.Substitution."""
def decorator(func: F) -> F:
substitutor = pandas.util.Substitution(*args, **kwargs) # type: ignore[attr-defined]
return cast(F, substitutor(func))
return decorator
def _is_field(name: str) -> bool:
"""Ignore private and reserved keywords."""
return not name.startswith("_") and name != _CONFIG_KEY
_config_options = [attr for attr in vars(BaseConfig) if _is_field(attr)]
def _extract_config_options_and_extras(
config: Any,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
config_options, extras = {}, {}
for name, value in vars(config).items():
if name in _config_options:
config_options[name] = value
elif _is_field(name):
extras[name] = value
# drop private/reserved keywords
return config_options, extras
def _convert_extras_to_checks(extras: Dict[str, Any]) -> List[Check]:
"""
New in GH#383.
Any key not in BaseConfig keys is interpreted as defining a dataframe check. This function
defines this conversion as follows:
- Look up the key name in Check
- If value is
- tuple: interpret as args
- dict: interpret as kwargs
- anything else: interpret as the only argument to pass to Check
"""
checks = []
for name, value in extras.items():
if isinstance(value, tuple):
args, kwargs = value, {}
elif isinstance(value, dict):
args, kwargs = (), value
else:
args, kwargs = (value,), {}
# dispatch directly to getattr to raise the correct exception
checks.append(Check.__getattr__(name)(*args, **kwargs))
return checks
class DataFrameModel(BaseModel):
"""Definition of a :class:`~pandera.api.pandas.container.DataFrameSchema`.
*new in 0.5.0*
.. important::
This class is the new name for ``SchemaModel``, which will be deprecated
in pandera version ``0.20.0``.
See the :ref:`User Guide <dataframe_models>` for more.
"""
Config: Type[BaseConfig] = BaseConfig
__extras__: Optional[Dict[str, Any]] = None
__schema__: Optional[DataFrameSchema] = None
__config__: Optional[Type[BaseConfig]] = None
#: Key according to `FieldInfo.name`
__fields__: Mapping[str, Tuple[AnnotationInfo, FieldInfo]] = {}
__checks__: Dict[str, List[Check]] = {}
__root_checks__: List[Check] = []
@docstring_substitution(validate_doc=DataFrameSchema.validate.__doc__)
def __new__(cls, *args, **kwargs) -> DataFrameBase[TDataFrameModel]: # type: ignore [misc]
"""%(validate_doc)s"""
return cast(
DataFrameBase[TDataFrameModel], cls.validate(*args, **kwargs)
)
def __init_subclass__(cls, **kwargs):
"""Ensure :class:`~pandera.api.pandas.model_components.FieldInfo` instances."""
if "Config" in cls.__dict__:
cls.Config.name = (
cls.Config.name
if hasattr(cls.Config, "name")
else cls.__name__
)
else:
cls.Config = type("Config", (BaseConfig,), {"name": cls.__name__})
super().__init_subclass__(**kwargs)
# pylint:disable=no-member
subclass_annotations = cls.__dict__.get("__annotations__", {})
for field_name in subclass_annotations.keys():
if _is_field(field_name) and field_name not in cls.__dict__:
# Field omitted
field = Field()
field.__set_name__(cls, field_name)
setattr(cls, field_name, field)
cls.__config__, cls.__extras__ = cls._collect_config_and_extras()
def __class_getitem__(
cls: Type[TDataFrameModel],
params: Union[Type[Any], Tuple[Type[Any], ...]],
) -> Type[TDataFrameModel]:
"""Parameterize the class's generic arguments with the specified types"""
if not hasattr(cls, "__parameters__"):
raise TypeError(
f"{cls.__name__} must inherit from typing.Generic before being parameterized"
)
# pylint: disable=no-member
__parameters__: Tuple[TypeVar, ...] = cls.__parameters__ # type: ignore
if not isinstance(params, tuple):
params = (params,)
if len(params) != len(__parameters__):
raise ValueError(
f"Expected {len(__parameters__)} generic arguments but found {len(params)}"
)
if (cls, params) in GENERIC_SCHEMA_CACHE:
return typing.cast(
Type[TDataFrameModel], GENERIC_SCHEMA_CACHE[(cls, params)]
)
param_dict: Dict[TypeVar, Type[Any]] = dict(
zip(__parameters__, params)
)
extra: Dict[str, Any] = {"__annotations__": {}}
for field, (annot_info, field_info) in cls._collect_fields().items():
if isinstance(annot_info.arg, TypeVar):
if annot_info.arg in param_dict:
raw_annot = annot_info.origin[param_dict[annot_info.arg]] # type: ignore
if annot_info.optional:
raw_annot = Optional[raw_annot]
extra["__annotations__"][field] = raw_annot
extra[field] = copy.deepcopy(field_info)
parameterized_name = (
f"{cls.__name__}[{', '.join(p.__name__ for p in params)}]"
)
parameterized_cls = type(parameterized_name, (cls,), extra)
GENERIC_SCHEMA_CACHE[(cls, params)] = parameterized_cls
return parameterized_cls
@classmethod
def to_schema(cls) -> DataFrameSchema:
"""Create :class:`~pandera.DataFrameSchema` from the :class:`.DataFrameModel`."""
if cls in MODEL_CACHE:
return MODEL_CACHE[cls]
mi_kwargs = {
name[len("multiindex_") :]: value
for name, value in vars(cls.__config__).items()
if name.startswith("multiindex_")
}
cls.__fields__ = cls._collect_fields()
for field, (annot_info, _) in cls.__fields__.items():
if isinstance(annot_info.arg, TypeVar):
raise SchemaInitError(f"Field {field} has a generic data type")
check_infos = typing.cast(
List[FieldCheckInfo], cls._collect_check_infos(CHECK_KEY)
)
cls.__checks__ = cls._extract_checks(
check_infos, field_names=list(cls.__fields__.keys())
)
df_check_infos = cls._collect_check_infos(DATAFRAME_CHECK_KEY)
df_custom_checks = cls._extract_df_checks(df_check_infos)
df_registered_checks = _convert_extras_to_checks(
{} if cls.__extras__ is None else cls.__extras__
)
cls.__root_checks__ = df_custom_checks + df_registered_checks
columns, index = cls._build_columns_index(
cls.__fields__, cls.__checks__, **mi_kwargs
)
kwargs = {}
if cls.__config__ is not None:
kwargs = {
"dtype": cls.__config__.dtype,
"coerce": cls.__config__.coerce,
"strict": cls.__config__.strict,
"name": cls.__config__.name,
"ordered": cls.__config__.ordered,
"unique": cls.__config__.unique,
"title": cls.__config__.title,
"description": cls.__config__.description or cls.__doc__,
"unique_column_names": cls.__config__.unique_column_names,
"add_missing_columns": cls.__config__.add_missing_columns,
"drop_invalid_rows": cls.__config__.drop_invalid_rows,
}
cls.__schema__ = DataFrameSchema(
columns,
index=index,
checks=cls.__root_checks__, # type: ignore
**kwargs, # type: ignore
)
if cls not in MODEL_CACHE:
MODEL_CACHE[cls] = cls.__schema__ # type: ignore
return cls.__schema__ # type: ignore
@classmethod
def to_yaml(cls, stream: Optional[os.PathLike] = None):
"""
Convert `Schema` to yaml using `io.to_yaml`.
"""
return cls.to_schema().to_yaml(stream)
@classmethod
@docstring_substitution(validate_doc=DataFrameSchema.validate.__doc__)
def validate(
cls: Type[TDataFrameModel],
check_obj: pd.DataFrame,
head: Optional[int] = None,
tail: Optional[int] = None,
sample: Optional[int] = None,
random_state: Optional[int] = None,
lazy: bool = False,
inplace: bool = False,
) -> DataFrameBase[TDataFrameModel]:
"""%(validate_doc)s"""
return cast(
DataFrameBase[TDataFrameModel],
cls.to_schema().validate(
check_obj, head, tail, sample, random_state, lazy, inplace
),
)
@classmethod
@docstring_substitution(strategy_doc=DataFrameSchema.strategy.__doc__)
@st.strategy_import_error
def strategy(cls: Type[TDataFrameModel], **kwargs):
"""%(strategy_doc)s"""
return cls.to_schema().strategy(**kwargs)
@classmethod
@docstring_substitution(example_doc=DataFrameSchema.strategy.__doc__)
@st.strategy_import_error
def example(
cls: Type[TDataFrameModel],
**kwargs,
) -> DataFrameBase[TDataFrameModel]:
"""%(example_doc)s"""
return cast(
DataFrameBase[TDataFrameModel], cls.to_schema().example(**kwargs)
)
@classmethod
def _build_columns_index( # pylint:disable=too-many-locals
cls,
fields: Dict[str, Tuple[AnnotationInfo, FieldInfo]],
checks: Dict[str, List[Check]],
**multiindex_kwargs: Any,
) -> Tuple[Dict[str, Column], Optional[Union[Index, MultiIndex]],]:
index_count = sum(
annotation.origin in INDEX_TYPES
for annotation, _ in fields.values()
)
columns: Dict[str, Column] = {}
indices: List[Index] = []
for field_name, (annotation, field) in fields.items():
field_checks = checks.get(field_name, [])
field_name = field.name
check_name = getattr(field, "check_name", None)
if annotation.metadata:
if field.dtype_kwargs:
raise TypeError(
"Cannot specify redundant 'dtype_kwargs' "
+ f"for {annotation.raw_annotation}."
+ "\n Usage Tip: Drop 'typing.Annotated'."
)
dtype_kwargs = _get_dtype_kwargs(annotation)
dtype = annotation.arg(**dtype_kwargs) # type: ignore
elif annotation.default_dtype:
dtype = annotation.default_dtype
else:
dtype = annotation.arg
dtype = None if dtype is Any else dtype
if (
annotation.origin is None
or annotation.origin in SERIES_TYPES
or annotation.raw_annotation in SERIES_TYPES
):
col_constructor = field.to_column if field else Column
if check_name is False:
raise SchemaInitError(
f"'check_name' is not supported for {field_name}."
)
columns[field_name] = col_constructor( # type: ignore
dtype,
required=not annotation.optional,
checks=field_checks,
name=field_name,
)
elif (
annotation.origin in INDEX_TYPES
or annotation.raw_annotation in INDEX_TYPES
):
if annotation.optional:
raise SchemaInitError(
f"Index '{field_name}' cannot be Optional."
)
if check_name is False or (
# default single index
check_name is None
and index_count == 1
):
field_name = None # type:ignore
index_constructor = field.to_index if field else Index
index = index_constructor( # type: ignore
dtype, checks=field_checks, name=field_name
)
indices.append(index)
else:
raise SchemaInitError(
f"Invalid annotation '{field_name}: "
f"{annotation.raw_annotation}'"
)
return columns, _build_schema_index(indices, **multiindex_kwargs)
@classmethod
def _get_model_attrs(cls) -> Dict[str, Any]:
"""Return all attributes.
Similar to inspect.get_members but bypass descriptors __get__.
"""
bases = inspect.getmro(cls)[:-1] # bases -> DataFrameModel -> object
attrs = {}
for base in reversed(bases):
if issubclass(base, DataFrameModel):
attrs.update(base.__dict__)
return attrs
@classmethod
def _collect_fields(cls) -> Dict[str, Tuple[AnnotationInfo, FieldInfo]]:
"""Centralize publicly named fields and their corresponding annotations."""
# pylint: disable=unexpected-keyword-arg
annotations = get_type_hints( # type: ignore[call-arg]
cls,
include_extras=True,
)
# pylint: enable=unexpected-keyword-arg
attrs = cls._get_model_attrs()
missing = []
for name, attr in attrs.items():
if inspect.isroutine(attr):
continue
if not _is_field(name):
annotations.pop(name, None)
elif name not in annotations:
missing.append(name)
if missing:
raise SchemaInitError(f"Found missing annotations: {missing}")
fields = {}
for field_name, annotation in annotations.items():
field = attrs[field_name] # __init_subclass__ guarantees existence
if not isinstance(field, FieldInfo):
raise SchemaInitError(
f"'{field_name}' can only be assigned a 'Field', "
+ f"not a '{type(field)}.'"
)
fields[field.name] = (AnnotationInfo(annotation), field)
return fields
@classmethod
def _collect_config_and_extras(
cls,
) -> Tuple[Type[BaseConfig], Dict[str, Any]]:
"""Collect config options from bases, splitting off unknown options."""
bases = inspect.getmro(cls)[:-1]
bases = tuple(
base for base in bases if issubclass(base, DataFrameModel)
)
root_model, *models = reversed(bases)
options, extras = _extract_config_options_and_extras(root_model.Config)
for model in models:
config = getattr(model, _CONFIG_KEY, {})
base_options, base_extras = _extract_config_options_and_extras(
config
)
options.update(base_options)
extras.update(base_extras)
return type("Config", (BaseConfig,), options), extras
@classmethod
def _collect_check_infos(cls, key: str) -> List[CheckInfo]:
"""Collect inherited check metadata from bases.
Inherited classmethods are not in cls.__dict__, that's why we need to
walk the inheritance tree.
"""
bases = inspect.getmro(cls)[:-2] # bases -> DataFrameModel -> object
bases = tuple(
base for base in bases if issubclass(base, DataFrameModel)
)
method_names = set()
check_infos = []
for base in bases:
for attr_name, attr_value in vars(base).items():
check_info = getattr(attr_value, key, None)
if not isinstance(check_info, CheckInfo):
continue
if attr_name in method_names: # check overridden by subclass
continue
method_names.add(attr_name)
check_infos.append(check_info)
return check_infos
@classmethod
def _extract_checks(
cls, check_infos: List[FieldCheckInfo], field_names: List[str]
) -> Dict[str, List[Check]]:
"""Collect field annotations from bases in mro reverse order."""
checks: Dict[str, List[Check]] = {}
for check_info in check_infos:
check_info_fields = {
field.name if isinstance(field, FieldInfo) else field
for field in check_info.fields
}
if check_info.regex:
matched = _regex_filter(field_names, check_info_fields)
else:
matched = check_info_fields
check_ = check_info.to_check(cls)
for field in matched:
if field not in field_names:
raise SchemaInitError(
f"Check {check_.name} is assigned to a non-existing field '{field}'."
)
if field not in checks:
checks[field] = []
checks[field].append(check_)
return checks
@classmethod
def _extract_df_checks(cls, check_infos: List[CheckInfo]) -> List[Check]:
"""Collect field annotations from bases in mro reverse order."""
return [check_info.to_check(cls) for check_info in check_infos]
@classmethod
def get_metadata(cls) -> Optional[dict]:
"""Provide metadata for columns and schema level"""
res: Dict[Any, Any] = {"columns": {}}
columns = cls._collect_fields()
for k, (_, v) in columns.items():
res["columns"][k] = v.properties["metadata"]
res["dataframe"] = cls.Config.metadata
meta = {}
meta[cls.Config.name] = res
return meta
@classmethod
def pydantic_validate(cls, schema_model: Any) -> "DataFrameModel":
"""Verify that the input is a compatible dataframe model."""
if not inspect.isclass(schema_model): # type: ignore
raise TypeError(f"{schema_model} is not a pandera.DataFrameModel")
if not issubclass(schema_model, cls): # type: ignore
raise TypeError(f"{schema_model} does not inherit {cls}.")
try:
schema_model.to_schema()
except SchemaInitError as exc:
raise ValueError(
f"Cannot use {cls} as a pydantic type as its "
"DataFrameModel cannot be converted to a DataFrameSchema.\n"
f"Please revisit the model to address the following errors:"
f"\n{exc}"
) from exc
return cast("DataFrameModel", schema_model)
if PYDANTIC_V2:
@classmethod
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
return core_schema.no_info_plain_validator_function(
cls.pydantic_validate,
)
@classmethod
def __get_pydantic_json_schema__(
cls,
_core_schema: core_schema.CoreSchema,
_handler: GetJsonSchemaHandler,
):
"""Update pydantic field schema."""
json_schema = _handler(_core_schema)
json_schema = _handler.resolve_ref_schema(json_schema)
json_schema.update(_to_json_schema(cls.to_schema()))
else:
@classmethod
def __modify_schema__(cls, field_schema):
"""Update pydantic field schema."""
field_schema.update(_to_json_schema(cls.to_schema()))
@classmethod
def __get_validators__(cls):
yield cls.pydantic_validate
SchemaModel = DataFrameModel
"""
Alias for DataFrameModel.
.. warning::
This subclass is necessary for backwards compatibility, and will be
deprecated in pandera version ``0.20.0`` in favor of
:py:class:`~pandera.api.pandas.model.DataFrameModel`
"""
def _build_schema_index(
indices: List[Index], **multiindex_kwargs: Any
) -> Optional[SchemaIndex]:
index: Optional[SchemaIndex] = None
if indices:
if len(indices) == 1:
index = indices[0]
else:
index = MultiIndex(indices, **multiindex_kwargs)
return index
def _regex_filter(seq: Iterable, regexps: Iterable[str]) -> Set[str]:
"""Filter items matching at least one of the regexes."""
matched: Set[str] = set()
for regex in regexps:
pattern = re.compile(regex)
matched.update(filter(pattern.match, seq))
return matched
def _get_dtype_kwargs(annotation: AnnotationInfo) -> Dict[str, Any]:
sig = inspect.signature(annotation.arg) # type: ignore
dtype_arg_names = list(sig.parameters.keys())
if len(annotation.metadata) != len(dtype_arg_names): # type: ignore
raise TypeError(
f"Annotation '{annotation.arg.__name__}' requires " # type: ignore
+ f"all positional arguments {dtype_arg_names}."
)
return dict(zip(dtype_arg_names, annotation.metadata)) # type: ignore
def _to_json_schema(dataframe_schema):
"""Serialize schema metadata into json-schema format.
:param dataframe_schema: schema to write to json-schema format.
.. note::
This function is currently does not fully specify a pandera schema,
and is primarily used internally to render OpenAPI docs via the
FastAPI integration.
"""
empty = pd.DataFrame(columns=dataframe_schema.columns.keys()).astype(
{k: v.type for k, v in dataframe_schema.dtypes.items()}
)
table_schema = pd.io.json.build_table_schema(empty)
def _field_json_schema(field):
return {
"type": "array",
"items": {"type": field["type"]},
}
return {
"title": dataframe_schema.name or "pandera.DataFrameSchema",
"type": "object",
"properties": {
field["name"]: _field_json_schema(field)
for field in table_schema["fields"]
},
}