|
2 | 2 |
|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
5 | | -from typing import TYPE_CHECKING |
| 5 | +from typing import Any, Callable |
6 | 6 |
|
7 | 7 | from griffe import ( |
8 | 8 | Attribute, |
9 | 9 | Class, |
10 | 10 | Docstring, |
11 | 11 | Function, |
| 12 | + Kind, |
12 | 13 | get_logger, |
13 | 14 | ) |
14 | 15 | from pydantic.fields import FieldInfo |
15 | 16 |
|
16 | 17 | from griffe_pydantic import common |
17 | 18 |
|
18 | | -if TYPE_CHECKING: |
19 | | - from griffe import ObjectNode |
20 | | - |
21 | 19 | logger = get_logger(__name__) |
22 | 20 |
|
23 | 21 |
|
24 | | -def process_attribute(node: ObjectNode, attr: Attribute, cls: Class) -> None: |
| 22 | +def process_attribute(obj: Any, attr: Attribute, cls: Class, *, processed: set[str]) -> None: |
25 | 23 | """Handle Pydantic fields.""" |
| 24 | + if attr.canonical_path in processed: |
| 25 | + return |
| 26 | + processed.add(attr.canonical_path) |
26 | 27 | if attr.name == "model_config": |
27 | | - cls.extra[common.self_namespace]["config"] = node.obj |
| 28 | + cls.extra[common.self_namespace]["config"] = obj |
28 | 29 | return |
29 | 30 |
|
30 | | - if not isinstance(node.obj, FieldInfo): |
| 31 | + if not isinstance(obj, FieldInfo): |
31 | 32 | return |
32 | 33 |
|
33 | 34 | attr.labels = {"pydantic-field"} |
34 | | - attr.value = node.obj.default |
| 35 | + attr.value = obj.default |
35 | 36 | constraints = {} |
36 | 37 | for constraint in common.field_constraints: |
37 | | - if (value := getattr(node.obj, constraint, None)) is not None: |
| 38 | + if (value := getattr(obj, constraint, None)) is not None: |
38 | 39 | constraints[constraint] = value |
39 | 40 | attr.extra[common.self_namespace]["constraints"] = constraints |
40 | 41 |
|
41 | 42 | # Populate docstring from the field's `description` argument. |
42 | | - if not attr.docstring and (docstring := node.obj.description): |
| 43 | + if not attr.docstring and (docstring := obj.description): |
43 | 44 | attr.docstring = Docstring(docstring, parent=attr) |
44 | 45 |
|
45 | 46 |
|
46 | | -def process_function(node: ObjectNode, func: Function, cls: Class) -> None: |
| 47 | +def process_function(obj: Callable, func: Function, cls: Class, *, processed: set[str]) -> None: |
47 | 48 | """Handle Pydantic field validators.""" |
48 | | - if dec_info := getattr(node.obj, "decorator_info", None): |
| 49 | + if func.canonical_path in processed: |
| 50 | + return |
| 51 | + processed.add(func.canonical_path) |
| 52 | + if dec_info := getattr(obj, "decorator_info", None): |
49 | 53 | common.process_function(func, cls, dec_info.fields) |
50 | 54 |
|
51 | 55 |
|
52 | | -def process_class(node: ObjectNode, cls: Class) -> None: |
| 56 | +def process_class(obj: type, cls: Class, *, processed: set[str], schema: bool = False) -> None: |
53 | 57 | """Detect and prepare Pydantic models.""" |
54 | 58 | common.process_class(cls) |
55 | | - cls.extra[common.self_namespace]["schema"] = common.json_schema(node.obj) |
| 59 | + if schema: |
| 60 | + cls.extra[common.self_namespace]["schema"] = common.json_schema(obj) |
| 61 | + for member in cls.all_members.values(): |
| 62 | + kind = member.kind |
| 63 | + if kind is Kind.ATTRIBUTE: |
| 64 | + process_attribute(getattr(obj, member.name), member, cls, processed=processed) # type: ignore[arg-type] |
| 65 | + elif kind is Kind.FUNCTION: |
| 66 | + process_function(getattr(obj, member.name), member, cls, processed=processed) # type: ignore[arg-type] |
0 commit comments