|
4 | 4 | from collections import defaultdict |
5 | 5 | from math import sqrt |
6 | 6 | from operator import attrgetter |
7 | | -from typing import Callable, List, Optional, Set, Tuple, Union |
| 7 | +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set, Tuple, Union |
8 | 8 |
|
9 | 9 | import numpy as np |
10 | 10 | from scipy.linalg import norm |
@@ -142,15 +142,20 @@ class _Interval: |
142 | 142 | def __init__( |
143 | 143 | self, a: Union[int, float], b: Union[int, float], depth: int, rdepth: int |
144 | 144 | ) -> None: |
145 | | - self.children = [] |
146 | | - self.data = {} |
| 145 | + self.children: List["_Interval"] = [] |
| 146 | + self.data: Dict[float, float] = {} |
147 | 147 | self.a = a |
148 | 148 | self.b = b |
149 | 149 | self.depth = depth |
150 | 150 | self.rdepth = rdepth |
151 | | - self.done_leaves = set() |
152 | | - self.depth_complete = None |
| 151 | + self.done_leaves: Set["_Interval"] = set() |
| 152 | + self.depth_complete: Optional[int] = None |
153 | 153 | self.removed = False |
| 154 | + if TYPE_CHECKING: |
| 155 | + self.ndiv: int |
| 156 | + self.parent: Optional["_Interval"] |
| 157 | + self.err: float |
| 158 | + self.c: np.ndarray |
154 | 159 |
|
155 | 160 | @classmethod |
156 | 161 | def make_first(cls, a: int, b: int, depth: int = 2) -> "_Interval": |
@@ -234,7 +239,7 @@ def calc_err(self, c_old: np.ndarray) -> float: |
234 | 239 |
|
235 | 240 | def calc_ndiv(self) -> None: |
236 | 241 | div = self.parent.c00 and self.c00 / self.parent.c00 > 2 |
237 | | - self.ndiv += div |
| 242 | + self.ndiv += int(div) |
238 | 243 |
|
239 | 244 | if self.ndiv > ndiv_max and 2 * self.ndiv > self.rdepth: |
240 | 245 | raise DivergentIntegralError |
@@ -378,12 +383,14 @@ def __init__(self, function: Callable, bounds: Tuple[int, int], tol: float) -> N |
378 | 383 | self.bounds = bounds |
379 | 384 | self.tol = tol |
380 | 385 | self.max_ivals = 1000 |
381 | | - self.priority_split = [] |
| 386 | + self.priority_split: List[_Interval] = [] |
382 | 387 | self.data = {} |
383 | 388 | self.pending_points = set() |
384 | | - self._stack = [] |
385 | | - self.x_mapping = defaultdict(lambda: SortedSet([], key=attrgetter("rdepth"))) |
386 | | - self.ivals = set() |
| 389 | + self._stack: List[float] = [] |
| 390 | + self.x_mapping: Dict[float, SortedSet] = defaultdict( |
| 391 | + lambda: SortedSet([], key=attrgetter("rdepth")) |
| 392 | + ) |
| 393 | + self.ivals: Set[_Interval] = set() |
387 | 394 | ival = _Interval.make_first(*self.bounds) |
388 | 395 | self.add_ival(ival) |
389 | 396 | self.first_ival = ival |
|
0 commit comments