diff --git a/src/synthesized_datasets/_datasets.py b/src/synthesized_datasets/_datasets.py index ab11602..802a1c9 100644 --- a/src/synthesized_datasets/_datasets.py +++ b/src/synthesized_datasets/_datasets.py @@ -62,8 +62,8 @@ def load_spark(self, spark: _typing.Optional[_ps.SparkSession] = None) -> _ps.Da spark = _ps.SparkSession.builder.getOrCreate() spark.sparkContext.addFile(self.url) - _, ext = _os.path.splitext(self.url) - df = spark.read.csv(_SparkFiles.get("".join([self.name, ext])), header=True, inferSchema=True) + _, filename = _os.path.split(self.url) + df = spark.read.csv(_SparkFiles.get(filename), header=True, inferSchema=True) df.name = self.name return df