diff --git a/src/grelu/lightning/__init__.py b/src/grelu/lightning/__init__.py index 5079b9b..fcaa6ce 100644 --- a/src/grelu/lightning/__init__.py +++ b/src/grelu/lightning/__init__.py @@ -860,8 +860,22 @@ def predict_on_dataset( if return_df: if (preds.ndim == 3) and (preds.shape[-1] == 1): + n_tasks_pred = preds.shape[-2] + task_names = ( + self.data_params.get("tasks", {}).get("name", None) + ) + if task_names is None or n_tasks_pred != len(task_names): + if task_names is not None: + warnings.warn( + f"Prediction has {n_tasks_pred} task(s) but the model" + f" has {len(task_names)} task name(s), likely due to a" + " prediction transform. Using generic column names." + ) + task_names = [ + f"task_{i}" for i in range(n_tasks_pred) + ] preds = pd.DataFrame( - preds.squeeze(-1), columns=self.data_params["tasks"]["name"] + preds.squeeze(-1), columns=task_names ) else: warnings.warn( diff --git a/tests/test_lightning.py b/tests/test_lightning.py index d32184a..6db7dd4 100644 --- a/tests/test_lightning.py +++ b/tests/test_lightning.py @@ -256,6 +256,11 @@ def test_lightning_model_transform(): preds = multitask_profile_model.predict_on_dataset(udataset) assert preds.shape == (2, 1, 1) + # return_df=True should work after transform reduces tasks + preds_df = multitask_profile_model.predict_on_dataset(udataset, return_df=True) + assert isinstance(preds_df, pd.DataFrame) + assert preds_df.shape == (2, 1) + # Remove multitask_profile_model.reset_transform() preds = multitask_profile_model.predict_on_dataset(udataset)