From 3e6124952fcb778eb19a784096a644f20d02b266 Mon Sep 17 00:00:00 2001 From: Liudeng Zhang Date: Wed, 18 Mar 2026 23:47:08 -0500 Subject: [PATCH 1/2] Fix return_df after transform reduces number of tasks predict_on_dataset(return_df=True) raised a ValueError when a prediction transform (e.g. Specificity) reduced the number of tasks, because the DataFrame column names came from the original task list. Use the actual prediction shape and fall back to generic column names when the task count no longer matches. Closes #175 --- src/grelu/lightning/__init__.py | 13 ++++++++++++- tests/test_lightning.py | 5 +++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/src/grelu/lightning/__init__.py b/src/grelu/lightning/__init__.py index 7159fde..6dc8f92 100644 --- a/src/grelu/lightning/__init__.py +++ b/src/grelu/lightning/__init__.py @@ -793,8 +793,19 @@ def predict_on_dataset( if return_df: if (preds.ndim == 3) and (preds.shape[-1] == 1): + task_names = self.data_params["tasks"]["name"] + n_tasks_pred = preds.shape[-2] + if n_tasks_pred != len(task_names): + warnings.warn( + f"Prediction has {n_tasks_pred} task(s) but the model" + f" has {len(task_names)} task name(s), likely due to a" + " prediction transform. Using generic column names." + ) + task_names = [ + f"task_{i}" for i in range(n_tasks_pred) + ] preds = pd.DataFrame( - preds.squeeze(-1), columns=self.data_params["tasks"]["name"] + preds.squeeze(-1), columns=task_names ) else: warnings.warn( diff --git a/tests/test_lightning.py b/tests/test_lightning.py index d734ba7..835791f 100644 --- a/tests/test_lightning.py +++ b/tests/test_lightning.py @@ -256,6 +256,11 @@ def test_lightning_model_transform(): preds = multitask_profile_model.predict_on_dataset(udataset) assert preds.shape == (2, 1, 1) + # return_df=True should work after transform reduces tasks + preds_df = multitask_profile_model.predict_on_dataset(udataset, return_df=True) + assert isinstance(preds_df, pd.DataFrame) + assert preds_df.shape == (2, 1) + # Remove multitask_profile_model.reset_transform() preds = multitask_profile_model.predict_on_dataset(udataset) From 4ffe6c14b42f746bd36ac060e916bb5b493ad190 Mon Sep 17 00:00:00 2001 From: Liudeng Zhang Date: Fri, 24 Apr 2026 18:54:43 -0500 Subject: [PATCH 2/2] Handle missing task names in return_df when data_params lacks tasks Use .get() to safely access data_params["tasks"]["name"] so return_df works on models that have not been fit (e.g. loaded from checkpoint without training data). --- src/grelu/lightning/__init__.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/src/grelu/lightning/__init__.py b/src/grelu/lightning/__init__.py index cdbb592..fcaa6ce 100644 --- a/src/grelu/lightning/__init__.py +++ b/src/grelu/lightning/__init__.py @@ -860,14 +860,17 @@ def predict_on_dataset( if return_df: if (preds.ndim == 3) and (preds.shape[-1] == 1): - task_names = self.data_params["tasks"]["name"] n_tasks_pred = preds.shape[-2] - if n_tasks_pred != len(task_names): - warnings.warn( - f"Prediction has {n_tasks_pred} task(s) but the model" - f" has {len(task_names)} task name(s), likely due to a" - " prediction transform. Using generic column names." - ) + task_names = ( + self.data_params.get("tasks", {}).get("name", None) + ) + if task_names is None or n_tasks_pred != len(task_names): + if task_names is not None: + warnings.warn( + f"Prediction has {n_tasks_pred} task(s) but the model" + f" has {len(task_names)} task name(s), likely due to a" + " prediction transform. Using generic column names." + ) task_names = [ f"task_{i}" for i in range(n_tasks_pred) ]