|
14 | 14 | StringType, |
15 | 15 | StructField, |
16 | 16 | StructType, |
| 17 | + TimestampType, |
17 | 18 | ) |
18 | 19 | import pytest |
19 | 20 |
|
@@ -59,6 +60,39 @@ def test_pandas_to_spark_schema_nullable(self): |
59 | 60 | schema = _pandas_to_spark_schema(df, nullable=False) |
60 | 61 | assert not schema.fields[0].nullable |
61 | 62 |
|
def test_pandas_to_spark_schema_datetime_types(self):
    """Test conversion of pandas datetime types to Spark TimestampType.

    Builds a four-column DataFrame covering nanosecond and millisecond
    resolution datetimes, checks that the fixture really has the intended
    pandas dtypes, and then asserts that ``_pandas_to_spark_schema`` (called
    without the ``nullable`` argument) maps every column to a nullable
    ``TimestampType`` field.
    """
    timestamps = ["2023-01-01 10:00:00", "2023-01-02 11:00:00"]
    naive = pd.to_datetime(timestamps)
    aware = pd.to_datetime(timestamps, utc=True)

    data = {
        "datetime_ns": naive,
        "datetime_ns_utc": aware,
        "datetime_ms": naive.astype("datetime64[ms]"),
        # NOTE: the timezone is stripped before the cast, so despite its
        # name this column is tz-naive millisecond data.
        "datetime_ms_utc": aware.tz_localize(None).astype("datetime64[ms]"),
    }
    df = pd.DataFrame(data)

    # Validate the fixture first: if pandas did not produce these dtypes,
    # the failure should point at the test data, not at the converter.
    expected_dtypes = {
        "datetime_ns": "datetime64[ns]",
        "datetime_ns_utc": "datetime64[ns, UTC]",
        "datetime_ms": "datetime64[ms]",
        "datetime_ms_utc": "datetime64[ms]",
    }
    for column, dtype_name in expected_dtypes.items():
        assert str(df[column].dtype) == dtype_name

    # Convert to Spark schema
    schema = _pandas_to_spark_schema(df)

    # Verify the overall shape of the schema
    assert isinstance(schema, StructType)
    assert len(schema.fields) == 4

    # Check that all datetime columns map to a nullable TimestampType
    field_dict = {field.name: field for field in schema.fields}
    for field_name in ["datetime_ns", "datetime_ns_utc", "datetime_ms", "datetime_ms_utc"]:
        field = field_dict[field_name]
        assert isinstance(field.dataType, TimestampType), (
            f"Field {field_name} should be TimestampType, got {type(field_dict[field_name].dataType)}"
        )
        assert field.nullable
| 95 | + |
62 | 96 |
|
63 | 97 | # Completely isolated test class for QueryAPIDataCloudReader |
64 | 98 | @pytest.mark.usefixtures("patch_all_requests") |
|
0 commit comments