|
2 | 2 | import itertools |
3 | 3 | import math |
4 | 4 | from copy import deepcopy |
5 | | -from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union |
| 5 | +from typing import ( |
| 6 | + Any, |
| 7 | + Callable, |
| 8 | + Dict, |
| 9 | + Iterable, |
| 10 | + List, |
| 11 | + Literal, |
| 12 | + Optional, |
| 13 | + Sequence, |
| 14 | + Tuple, |
| 15 | + Union, |
| 16 | +) |
6 | 17 |
|
7 | 18 | import numpy as np |
8 | 19 | from sortedcollections.recipes import ItemSortedDict |
@@ -64,7 +75,11 @@ def abs_min_log_loss(xs, ys): |
64 | 75 |
|
65 | 76 | @uses_nth_neighbors(1) |
66 | 77 | def triangle_loss( |
67 | | - xs: Sequence[float], ys: Union[Iterable[float], Iterable[Iterable[float]]] |
| 78 | + xs: Sequence[Union[float, None]], |
| 79 | + ys: Union[ |
| 80 | + Iterable[Union[float, None]], |
| 81 | + Iterable[Union[Iterable[float], None]], |
| 82 | + ], |
68 | 83 | ) -> float: |
69 | 84 | xs = [x for x in xs if x is not None] |
70 | 85 | ys = [y for y in ys if y is not None] |
@@ -399,7 +414,7 @@ def _update_scale(self, x: float, y: Union[float, np.ndarray]) -> None: |
399 | 414 | self._bbox[1][1] = max(self._bbox[1][1], y) |
400 | 415 | self._scale[1] = self._bbox[1][1] - self._bbox[1][0] |
401 | 416 |
|
402 | | - def tell(self, x: float, y: Union[float, np.ndarray]) -> None: |
| 417 | + def tell(self, x: float, y: Union[float, Sequence[float], np.ndarray]) -> None: |
403 | 418 | if x in self.data: |
404 | 419 | # The point is already evaluated before |
405 | 420 | return |
@@ -442,7 +457,7 @@ def tell_pending(self, x: float) -> None: |
442 | 457 | self._update_neighbors(x, self.neighbors_combined) |
443 | 458 | self._update_losses(x, real=False) |
444 | 459 |
|
445 | | - def tell_many(self, xs: List[float], ys: List[Any], *, force=False) -> None: |
| 460 | + def tell_many(self, xs: Sequence[float], ys: Sequence[Any], *, force=False) -> None: |
446 | 461 | if not force and not (len(xs) > 0.5 * len(self.data) and len(xs) > 2): |
447 | 462 | # Only run this more efficient method if there are |
448 | 463 | # at least 2 points and the amount of points added are |
@@ -602,7 +617,7 @@ def _loss(self, mapping: ItemSortedDict, ival: Any) -> Any: |
602 | 617 | loss = mapping[ival] |
603 | 618 | return finite_loss(ival, loss, self._scale[0]) |
604 | 619 |
|
605 | | - def plot(self, *, scatter_or_line="scatter"): |
| 620 | + def plot(self, *, scatter_or_line: Literal["scatter", "line"] = "scatter"): |
606 | 621 | """Returns a plot of the evaluated data. |
607 | 622 |
|
608 | 623 | Parameters |
|
0 commit comments