Skip to content

Commit 72fb1a6

Browse files
nardi and Nardi Lam authored
feat: Allow specifying Duration time unit (#301)
Co-authored-by: Nardi Lam <nardi@gradyent.ai>
1 parent 6e06c73 commit 72fb1a6

File tree

3 files changed: +17 additions, -3 deletions

dataframely/columns/datetime.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,7 @@ def __init__(
440440
max: dt.timedelta | None = None,
441441
max_exclusive: dt.timedelta | None = None,
442442
resolution: str | None = None,
443+
time_unit: TimeUnit = "us",
443444
check: Check | None = None,
444445
alias: str | None = None,
445446
metadata: dict[str, Any] | None = None,
@@ -462,6 +463,7 @@ def __init__(
462463
the formatting language used by :mod:`polars` datetime `truncate` method.
463464
For example, a value `1h` expects all durations to be full hours. Note
464465
that this setting does *not* affect the storage resolution.
466+
time_unit: Unit of time. Defaults to `us` (microseconds).
465467
check: A custom rule or multiple rules to run for this column. This can be:
466468
- A single callable that returns a non-aggregated boolean expression.
467469
The name of the rule is derived from the callable name, or defaults to
@@ -504,10 +506,11 @@ def __init__(
504506
metadata=metadata,
505507
)
506508
self.resolution = resolution
509+
self.time_unit = time_unit
507510

508511
@property
509512
def dtype(self) -> pl.DataType:
510-
return pl.Duration()
513+
return pl.Duration(time_unit=self.time_unit)
511514

512515
def validation_rules(self, expr: pl.Expr) -> dict[str, pl.Expr]:
513516
result = super().validation_rules(expr)
@@ -526,7 +529,7 @@ def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine:
526529

527530
@property
528531
def pyarrow_dtype(self) -> pa.DataType:
529-
return pa.duration("us")
532+
return pa.duration(self.time_unit)
530533

531534
def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series:
532535
# NOTE: If no duration is specified, we default to 100 years
@@ -543,6 +546,7 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series:
543546
default=dt.timedelta(days=365 * 100),
544547
),
545548
resolution=self.resolution,
549+
time_unit=self.time_unit,
546550
null_probability=self._null_probability,
547551
)
548552

dataframely/random.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,7 @@ def sample_duration(
376376
min: dt.timedelta,
377377
max: dt.timedelta,
378378
resolution: str | None = None,
379+
time_unit: TimeUnit = "us",
379380
null_probability: float = 0.0,
380381
) -> pl.Series:
381382
"""Sample a list of durations in the provided range.
@@ -386,6 +387,7 @@ def sample_duration(
386387
max: The maximum duration to sample (exclusive).
387388
resolution: The resolution that durations in the column must have. This uses
388389
the formatting language used by :mod:`polars` datetime `round` method.
390+
time_unit: The time unit of the duration column. Defaults to `us` (microseconds).
389391
null_probability: The probability of an element being `null`.
390392
391393
Returns:
@@ -410,7 +412,7 @@ def sample_duration(
410412
max=max_microseconds,
411413
null_probability=null_probability,
412414
)
413-
).cast(pl.Duration)
415+
).cast(pl.Duration(time_unit=time_unit))
414416

415417
if resolution is not None:
416418
ref_dt = pl.lit(EPOCH_DATETIME)

tests/columns/test_pyarrow.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,3 +266,11 @@ def test_datetime_time_unit(time_unit: TimeUnit) -> None:
266266
"test", {"a": dy.Datetime(time_unit=time_unit, nullable=True)}
267267
)
268268
assert str(schema.to_pyarrow_schema()) == f"a: timestamp[{time_unit}]"
269+
270+
271+
@pytest.mark.parametrize("time_unit", ["ns", "us", "ms"])
272+
def test_duration_time_unit(time_unit: TimeUnit) -> None:
273+
schema = create_schema(
274+
"test", {"a": dy.Duration(time_unit=time_unit, nullable=True)}
275+
)
276+
assert str(schema.to_pyarrow_schema()) == f"a: duration[{time_unit}]"

0 commit comments

Comments
 (0)