Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions src/datacustomcode/io/reader/query_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
logger = logging.getLogger(__name__)


SQL_QUERY_TEMPLATE: Final = "SELECT * FROM {}"
SQL_QUERY_TEMPLATE: Final = "SELECT * FROM {} LIMIT {}"
PANDAS_TYPE_MAPPING = {
"object": StringType(),
"int64": LongType(),
Expand Down Expand Up @@ -85,29 +85,40 @@ def __init__(self, spark: SparkSession) -> None:
)

def read_dlo(
    self,
    name: str,
    schema: Union[AtomicType, StructType, str, None] = None,
    row_limit: int = 1000,
) -> PySparkDataFrame:
    """
    Read a Data Lake Object (DLO) from the Data Cloud, limited to a number of rows.

    Args:
        name (str): The name of the DLO.
        schema (Optional[Union[AtomicType, StructType, str]]): Schema of the DLO.
            When omitted (or falsy), the schema is inferred from the pandas
            dtypes of the fetched result.
        row_limit (int): Maximum number of rows to fetch. Defaults to 1000.
            A value of 0 fetches no data rows and is useful for retrieving
            only the DLO's schema.

    Returns:
        PySparkDataFrame: The PySpark DataFrame.
    """
    # SQL_QUERY_TEMPLATE is "SELECT * FROM {} LIMIT {}"; the limit keeps
    # query-API reads bounded instead of pulling the entire object.
    pandas_df = self._conn.get_pandas_dataframe(
        SQL_QUERY_TEMPLATE.format(name, row_limit)
    )
    if not schema:
        # auto infer schema from the pandas dtypes of the result
        schema = _pandas_to_spark_schema(pandas_df)
    spark_dataframe = self.spark.createDataFrame(pandas_df, schema)
    return spark_dataframe

def read_dmo(
self, name: str, schema: Union[AtomicType, StructType, str, None] = None
self,
name: str,
schema: Union[AtomicType, StructType, str, None] = None,
row_limit: int = 1000,
) -> PySparkDataFrame:
pandas_df = self._conn.get_pandas_dataframe(SQL_QUERY_TEMPLATE.format(name))
pandas_df = self._conn.get_pandas_dataframe(
SQL_QUERY_TEMPLATE.format(name, row_limit)
)
if not schema:
# auto infer schema
schema = _pandas_to_spark_schema(pandas_df)
Expand Down
49 changes: 49 additions & 0 deletions src/datacustomcode/io/writer/print.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,67 @@

from pyspark.sql import DataFrame as PySparkDataFrame

from datacustomcode.io.reader.query_api import QueryAPIDataCloudReader
from datacustomcode.io.writer.base import BaseDataCloudWriter, WriteMode


class PrintDataCloudWriter(BaseDataCloudWriter):
    """Writer that prints DataFrames instead of persisting them to Data Cloud.

    Useful for local development: ``write_to_dlo`` still validates the
    DataFrame's columns against the target DLO schema so mismatches are
    caught early, but the data itself is only shown on stdout.
    """

    CONFIG_NAME = "PrintDataCloudWriter"

    def validate_dataframe_columns_against_dlo(
        self,
        dataframe: PySparkDataFrame,
        dlo_name: str,
        reader: QueryAPIDataCloudReader,
    ) -> None:
        """
        Validates that all columns in the given dataframe exist in the DLO schema.

        Args:
            dataframe (PySparkDataFrame): The DataFrame to validate.
            dlo_name (str): The name of the DLO to check against.
            reader (QueryAPIDataCloudReader): The reader to use for schema retrieval.

        Raises:
            ValueError: If any columns in the dataframe are not present in the DLO
                schema.
        """
        # row_limit=0 fetches the DLO schema only (no data rows transferred).
        dlo_df = reader.read_dlo(dlo_name, row_limit=0)
        dlo_columns = set(dlo_df.columns)
        df_columns = set(dataframe.columns)

        # Columns present in the dataframe but missing from the DLO schema.
        extra_columns = df_columns - dlo_columns
        if extra_columns:
            cols = sorted(extra_columns)
            raise ValueError(
                "The following columns are not present in the \n"
                f"DLO '{dlo_name}': {cols}.\n"
                "To fix this error, you can either:\n"
                " - Drop these columns from your DataFrame before writing, e.g.,\n"
                f"   dataframe = dataframe.drop({cols})\n"
                " - Or, add these columns to the DLO schema in Data Cloud."
            )

    def write_to_dlo(
        self, name: str, dataframe: PySparkDataFrame, write_mode: WriteMode
    ) -> None:
        """Validate the DataFrame against the DLO schema, then print it.

        ``write_mode`` is accepted for interface compatibility with
        BaseDataCloudWriter but is not used, since nothing is persisted.
        """
        # Instantiate a reader to fetch the DLO schema for validation.
        reader = QueryAPIDataCloudReader(self.spark)

        # Validate columns before proceeding.
        self.validate_dataframe_columns_against_dlo(dataframe, name, reader)

        dataframe.show()

    def write_to_dmo(
        self, name: str, dataframe: PySparkDataFrame, write_mode: WriteMode
    ) -> None:
        """Print the DataFrame without validation.

        The DLO-style column validation is not applicable here because the
        target DMO may not exist yet, so we just show the dataframe.
        """
        dataframe.show()
55 changes: 5 additions & 50 deletions src/datacustomcode/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import ast
import os
import sys
from typing import (
Any,
ClassVar,
Expand All @@ -40,6 +41,8 @@
},
}

STANDARD_LIBS = set(sys.stdlib_module_names)


class DataAccessLayerCalls(pydantic.BaseModel):
read_dlo: frozenset[str]
Expand Down Expand Up @@ -137,54 +140,6 @@ def found(self) -> DataAccessLayerCalls:
class ImportVisitor(ast.NodeVisitor):
"""AST Visitor that extracts external package imports from Python code."""

# Standard library modules that should be excluded from requirements
STANDARD_LIBS: ClassVar[set[str]] = {
"abc",
"argparse",
"ast",
"asyncio",
"base64",
"collections",
"configparser",
"contextlib",
"copy",
"csv",
"datetime",
"enum",
"functools",
"glob",
"hashlib",
"http",
"importlib",
"inspect",
"io",
"itertools",
"json",
"logging",
"math",
"os",
"pathlib",
"pickle",
"random",
"re",
"shutil",
"site",
"socket",
"sqlite3",
"string",
"subprocess",
"sys",
"tempfile",
"threading",
"time",
"traceback",
"typing",
"uuid",
"warnings",
"xml",
"zipfile",
}

# Additional packages to exclude from requirements.txt
EXCLUDED_PACKAGES: ClassVar[set[str]] = {
"datacustomcode", # Internal package
Expand All @@ -200,7 +155,7 @@ def visit_Import(self, node: ast.Import) -> None:
# Get the top-level package name
package = name.name.split(".")[0]
if (
package not in self.STANDARD_LIBS
package not in STANDARD_LIBS
and package not in self.EXCLUDED_PACKAGES
and not package.startswith("_")
):
Expand All @@ -213,7 +168,7 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
# Get the top-level package
package = node.module.split(".")[0]
if (
package not in self.STANDARD_LIBS
package not in STANDARD_LIBS
and package not in self.EXCLUDED_PACKAGES
and not package.startswith("_")
):
Expand Down
8 changes: 4 additions & 4 deletions tests/io/reader/test_query_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def test_read_dlo(

# Verify get_pandas_dataframe was called with the right SQL
mock_connection.get_pandas_dataframe.assert_called_once_with(
SQL_QUERY_TEMPLATE.format("test_dlo")
SQL_QUERY_TEMPLATE.format("test_dlo", 1000)
)

# Verify DataFrame was created with auto-inferred schema
Expand Down Expand Up @@ -172,7 +172,7 @@ def test_read_dlo_with_schema(

# Verify get_pandas_dataframe was called with the right SQL
mock_connection.get_pandas_dataframe.assert_called_once_with(
SQL_QUERY_TEMPLATE.format("test_dlo")
SQL_QUERY_TEMPLATE.format("test_dlo", 1000)
)

# Verify DataFrame was created with provided schema
Expand All @@ -192,7 +192,7 @@ def test_read_dmo(

# Verify get_pandas_dataframe was called with the right SQL
mock_connection.get_pandas_dataframe.assert_called_once_with(
SQL_QUERY_TEMPLATE.format("test_dmo")
SQL_QUERY_TEMPLATE.format("test_dmo", 1000)
)

# Verify DataFrame was created
Expand Down Expand Up @@ -220,7 +220,7 @@ def test_read_dmo_with_schema(

# Verify get_pandas_dataframe was called with the right SQL
mock_connection.get_pandas_dataframe.assert_called_once_with(
SQL_QUERY_TEMPLATE.format("test_dmo")
SQL_QUERY_TEMPLATE.format("test_dmo", 1000)
)

# Verify DataFrame was created with provided schema
Expand Down
Loading