|
35 | 35 |
|
36 | 36 | import abc |
37 | 37 | import collections |
| 38 | +from itertools import chain |
38 | 39 | import logging |
39 | 40 | import sys |
40 | 41 | import time |
@@ -1076,15 +1077,20 @@ def format( |
1076 | 1077 | **kwargs: t.Any, |
1077 | 1078 | ) -> bool: |
1078 | 1079 | """Format all SQL models and audits.""" |
| 1080 | + filtered_targets = [ |
| 1081 | + target |
| 1082 | + for target in chain(self._models.values(), self._audits.values()) |
| 1083 | + if target._path is not None |
| 1084 | + and target._path.suffix == ".sql" |
| 1085 | + and (not paths or any(target._path.samefile(p) for p in paths)) |
| 1086 | + ] |
1079 | 1087 | unformatted_file_paths = [] |
1080 | | - format_targets = {**self._models, **self._audits} |
1081 | 1088 |
|
1082 | | - for target in format_targets.values(): |
1083 | | - if target._path is None or target._path.suffix != ".sql": |
1084 | | - continue |
1085 | | - if paths and not any(target._path.samefile(p) for p in paths): |
| 1089 | + for target in filtered_targets: |
| 1090 | + if ( |
| 1091 | + target._path is None |
| 1092 | + ): # introduced to satisfy type checker as still want to pull filter out as many targets as possible before loop |
1086 | 1093 | continue |
1087 | | - |
1088 | 1094 | with open(target._path, "r+", encoding="utf-8") as file: |
1089 | 1095 | before = file.read() |
1090 | 1096 | expressions = parse(before, default_dialect=self.config_for_node(target).dialect) |
@@ -1125,7 +1131,6 @@ def format( |
1125 | 1131 | if unformatted_file_paths: |
1126 | 1132 | for path in unformatted_file_paths: |
1127 | 1133 | self.console.log_status_update(f"{path} needs reformatting.") |
1128 | | - |
1129 | 1134 | self.console.log_status_update( |
1130 | 1135 | f"\n{len(unformatted_file_paths)} file(s) need reformatting." |
1131 | 1136 | ) |
|
0 commit comments