|
13 | 13 | IncrementalByTimeRangeKind, |
14 | 14 | IncrementalByUniqueKeyKind, |
15 | 15 | TimeColumn, |
| 16 | + SCDType2ByColumnKind, |
16 | 17 | ) |
17 | 18 | from sqlmesh.core.node import IntervalUnit |
18 | 19 | from sqlmesh.core.scheduler import ( |
@@ -810,3 +811,69 @@ def signal_base(batch: DatetimeRanges): |
810 | 811 | snapshot_b: [(to_timestamp("2023-01-01"), to_timestamp("2023-01-04"))], |
811 | 812 | snapshot_c: [(to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))], |
812 | 813 | } |
| 814 | + |
| 815 | + |
| 816 | +@pytest.mark.parametrize( |
| 817 | + "batch_size, expected_batches", |
| 818 | + [ |
| 819 | + ( |
| 820 | + 1, |
| 821 | + [ |
| 822 | + (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), |
| 823 | + (to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), |
| 824 | + (to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), |
| 825 | + ], |
| 826 | + ), |
| 827 | + ( |
| 828 | + None, |
| 829 | + [ |
| 830 | + (to_timestamp("2023-01-01"), to_timestamp("2023-01-04")), |
| 831 | + ], |
| 832 | + ), |
| 833 | + ], |
| 834 | +) |
| 835 | +def test_scd_type_2_batch_size( |
| 836 | + mocker: MockerFixture, |
| 837 | + make_snapshot, |
| 838 | + get_batched_missing_intervals, |
| 839 | + batch_size: t.Optional[int], |
| 840 | + expected_batches: t.List[t.Tuple[int, int]], |
| 841 | +): |
| 842 | + """ |
| 843 | + Test that SCD_TYPE_2_BY_COLUMN models are batched correctly based on batch_size. |
| 844 | + With batch_size=1, we expect 3 separate batches for 3 days. |
| 845 | + Without a specified batch_size, we expect a single batch for the entire period. |
| 846 | + """ |
| 847 | + start = to_datetime("2023-01-01") |
| 848 | + end = to_datetime("2023-01-04") |
| 849 | + |
| 850 | + # Configure kind params |
| 851 | + kind_params = {} |
| 852 | + if batch_size is not None: |
| 853 | + kind_params["batch_size"] = batch_size |
| 854 | + |
| 855 | + # Create the model and snapshot |
| 856 | + model = SqlModel( |
| 857 | + name="test_scd_model", |
| 858 | + kind=SCDType2ByColumnKind(columns="valid_to", unique_key=["id"], **kind_params), |
| 859 | + cron="@daily", |
| 860 | + start=start, |
| 861 | + query=parse_one("SELECT id, valid_from, valid_to FROM source"), |
| 862 | + ) |
| 863 | + snapshot = make_snapshot(model) |
| 864 | + |
| 865 | + # Setup scheduler |
| 866 | + snapshot_evaluator = SnapshotEvaluator(adapters=mocker.MagicMock(), ddl_concurrent_tasks=1) |
| 867 | + scheduler = Scheduler( |
| 868 | + snapshots=[snapshot], |
| 869 | + snapshot_evaluator=snapshot_evaluator, |
| 870 | + state_sync=mocker.MagicMock(), |
| 871 | + max_workers=2, |
| 872 | + default_catalog=None, |
| 873 | + ) |
| 874 | + |
| 875 | + # Get batches for the time period |
| 876 | + batches = get_batched_missing_intervals(scheduler, start, end, end)[snapshot] |
| 877 | + |
| 878 | + # Verify batches match expectations |
| 879 | + assert batches == expected_batches |
0 commit comments