|
1 | 1 | import asyncio |
2 | 2 | import datetime |
| 3 | +import decimal |
3 | 4 | import functools |
4 | 5 | import os |
5 | 6 |
|
@@ -61,6 +62,14 @@ def process_result_value(self, value, dialect): |
61 | 62 | sqlalchemy.Column("published", MyEpochType), |
62 | 63 | ) |
63 | 64 |
|
| 65 | +# Used to test Numeric |
| 66 | +prices = sqlalchemy.Table( |
| 67 | + "prices", |
| 68 | + metadata, |
| 69 | + sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True), |
| 70 | + sqlalchemy.Column("price", sqlalchemy.Numeric(precision=30, scale=20)), |
| 71 | +) |
| 72 | + |
64 | 73 |
|
65 | 74 | @pytest.fixture(autouse=True, scope="module") |
66 | 75 | def create_test_database(): |
@@ -456,6 +465,33 @@ async def test_datetime_field(database_url): |
456 | 465 | assert results[0]["published"] == now |
457 | 466 |
|
458 | 467 |
|
| 468 | +@pytest.mark.parametrize("database_url", DATABASE_URLS) |
| 469 | +@async_adapter |
| 470 | +async def test_decimal_field(database_url): |
| 471 | + """ |
| 472 | + Test Decimal (NUMERIC) columns, to ensure records are coerced to/from proper Python types. |
| 473 | + """ |
| 474 | + |
| 475 | + async with Database(database_url) as database: |
| 476 | + async with database.transaction(force_rollback=True): |
| 477 | + price = decimal.Decimal("0.700000000000001") |
| 478 | + |
| 479 | + # execute() |
| 480 | + query = prices.insert() |
| 481 | + values = {"price": price} |
| 482 | + await database.execute(query, values) |
| 483 | + |
| 484 | + # fetch_all() |
| 485 | + query = prices.select() |
| 486 | + results = await database.fetch_all(query=query) |
| 487 | + assert len(results) == 1 |
| 488 | + if database_url.startswith("sqlite"): |
| 489 | + # aiosqlite does not support native decimals --> a roud-off error is expected |
| 490 | + assert results[0]["price"] == pytest.approx(price) |
| 491 | + else: |
| 492 | + assert results[0]["price"] == price |
| 493 | + |
| 494 | + |
459 | 495 | @pytest.mark.parametrize("database_url", DATABASE_URLS) |
460 | 496 | @async_adapter |
461 | 497 | async def test_json_field(database_url): |
|
0 commit comments