diff --git a/scikit_longitudinal/data_preparation/elsa_handler.py b/scikit_longitudinal/data_preparation/elsa_handler.py index e05dd5e..4a5caee 100644 --- a/scikit_longitudinal/data_preparation/elsa_handler.py +++ b/scikit_longitudinal/data_preparation/elsa_handler.py @@ -180,12 +180,15 @@ def save_datasets(self, dir_output: str = "tmp", file_format: str = "csv"): dataset.to_csv(f"{dir_output}/{class_name}_dataset.csv", index=False) elif file_format.lower() == "arff": dataset.fillna("?", inplace=True) - arff.dump( - f"{dir_output}/{class_name}_dataset.arff", - dataset.values, - relation=class_name, - names=dataset.columns, - ) + arff_data = { + 'description': '', + 'relation': class_name, + 'attributes': [(col, 'REAL' if dataset[col].dtype in ['float64', 'int64'] else 'STRING') for col in + dataset.columns], + 'data': dataset.values.tolist() + } + with open(f"{dir_output}/{class_name}_dataset.arff", 'w') as f: + arff.dump(arff_data, f) else: raise ValueError(f"Unsupported file format: {file_format}") diff --git a/scikit_longitudinal/data_preparation/longitudinal_dataset.py b/scikit_longitudinal/data_preparation/longitudinal_dataset.py index 89f770c..b972d34 100644 --- a/scikit_longitudinal/data_preparation/longitudinal_dataset.py +++ b/scikit_longitudinal/data_preparation/longitudinal_dataset.py @@ -485,12 +485,13 @@ def convert(self, output_path: Union[str, Path]) -> None: # pragma: no cover if file_ext == ".arff": arff_data = self._csv_to_arff(self._data, self.file_path.stem) - arff.dump( - output_path, - arff_data["data"], - relation=arff_data["relation"], - names=arff_data["attributes"], - ) + with open(output_path, 'w') as f: + arff.dump({ + 'description': '', + 'relation': arff_data['relation'], + 'attributes': arff_data['attributes'], + 'data': arff_data['data'] + }, f) elif file_ext == ".csv": self._data.to_csv(output_path, index=False, na_rep="") else: