[For less verbose and foolproof operations]
# Import PySpark, installing it into the running interpreter's environment if missing.
try:
    from pyspark import SparkConf
except ImportError:
    import importlib
    import subprocess
    import sys

    # Use the kernel's own interpreter so pip installs into the environment this
    # kernel imports from; a bare `!pip` may resolve to a different Python.
    subprocess.check_call([sys.executable, "-m", "pip", "install", "pyspark==3.2.1"])
    # The import system caches directory listings; refresh them so the freshly
    # installed package is found without restarting the kernel.
    importlib.invalidate_caches()
    from pyspark import SparkConf
from pyspark .sql import SparkSession , types as st
from IPython .display import HTML
import spark .helpers as sh
# Spark setup: single-threaded local master with the application name "examples".
conf = (
    SparkConf()
    .setMaster("local[1]")
    .setAppName("examples")
)
spark = (
    SparkSession.builder
    .config(conf=conf)
    .getOrCreate()
)
# Only surface ERROR-level messages from Spark to keep cell output readable.
spark.sparkContext.setLogLevel("ERROR")
# Load the two example CSV files, taking column names from each file's header row.
dataframe_1, dataframe_2 = (
    spark.read.options(header=True).csv(path)
    for path in ("./data/dataset_1.csv", "./data/dataset_2.csv")
)
# Render both sample datasets side by side as HTML tables.
# Each heading is opened and closed as <h4> (previously closed with </h3>).
html = (
    "<div style='float:left'><h4>Dataset 1:</h4>"
    + dataframe_1.toPandas().to_html()
    + "</div><div style='float:left; margin-left:50px;'><h4>Dataset 2:</h4>"
    + dataframe_2.toPandas().to_html()
    + "</div>"
)
HTML(html)
x1
x2
x3
x4
x5
0
A
J
734
499
595.0
1
B
J
357
202
525.0
2
C
H
864
568
433.5
3
D
J
530
703
112.3
4
E
H
61
521
906.0
5
F
H
482
496
13.0
6
G
A
350
279
941.0
7
H
C
171
267
423.0
8
I
C
755
133
600.0
9
J
A
228
765
7.0
x1
x3
x4
x6
x7
0
W
K
391
140
872.0
1
X
G
88
483
707.1
2
Y
M
144
476
714.3
3
Z
J
896
68
902.0
4
A
O
946
187
431.0
5
B
P
692
523
503.5
6
C
Q
550
988
181.05
7
D
R
50
419
42.0
8
E
S
824
805
558.2
9
F
T
69
722
721.0
# Print how many rows fall into each distinct value of column x2.
for key, subset in sh.group_iterator(dataframe_1, "x2"):
    print(key, " => ", len(subset.toPandas()))
A => 2
C => 2
H => 3
J => 3
[Group by multiple columns]
# With a list of columns, each group key is a tuple of values (x1, x2).
for key, subset in sh.group_iterator(dataframe_1, ["x1", "x2"]):
    print(key, " => ", len(subset.toPandas()))
('A', 'J') => 1
('B', 'J') => 1
('C', 'H') => 1
('D', 'J') => 1
('E', 'H') => 1
('F', 'H') => 1
('G', 'A') => 1
('H', 'C') => 1
('I', 'C') => 1
('J', 'A') => 1
def _schema_pairs(frame):
    """Return the DataFrame's columns as ``(name, type)`` pairs, in schema order."""
    return [
        (field["name"], field["type"])
        for field in frame.schema.jsonValue()["fields"]
    ]


# Column types as loaded, before any conversion.
before = _schema_pairs(dataframe_1)

# Columns to convert and their target Spark types.
# NOTE(review): in the sample data x2 holds letters (e.g. "J"); if change_schema
# casts values, x2 would likely become all-null. A numeric column such as x3
# may make a clearer example.
schema = {"x2": st.IntegerType(), "x5": st.FloatType()}
new_dataframe = sh.change_schema(dataframe_1, schema)
after = _schema_pairs(new_dataframe)

# Only the listed columns change type; the rest keep the type they were loaded with.
check = [
    ("x1", "string"),
    ("x2", "integer"),
    ("x3", "string"),
    ("x4", "string"),
    ("x5", "float"),
]
assert before != after
assert after == check
# Join the (x2, x5) slice of dataset 1 to dataset 2, matching left x2 against right x1.
left_slice = dataframe_1.select("x2", "x5")
on_x2_x1 = sh.JoinStatement("x2", "x1")
joined = sh.join(left_slice, dataframe_2, on_x2_x1)
joined.toPandas()
x1
x2
x3
x4
x5
x6
x7
0
A
A
O
946
7.0
187
431.0
1
A
A
O
946
941.0
187
431.0
2
C
C
Q
550
600.0
988
181.05
3
C
C
Q
550
423.0
988
181.05
[When there are overlapping columns]
# Joining on x1 leaves columns present in both dataframes; without an
# overwrite strategy, sh.join refuses with a ValueError.
try:
    joined = sh.join(dataframe_1, dataframe_2, sh.JoinStatement("x1"))
except ValueError as error:
    print(f"Error raised as expected: {error}")
else:
    # Fail loudly rather than silently skipping the demonstration.
    raise AssertionError("Expected a ValueError for overlapping columns")

# "left": keep every overlapping column from dataframe_1.
joined = sh.join(dataframe_1, dataframe_2, sh.JoinStatement("x1"), overwrite_strategy="left")
joined.toPandas()
Error raised as expected:
Overlapping columns found in the dataframes: ['x1', 'x3', 'x4']
Please provide the `overwrite_strategy` argument therefore, to select a selection strategy:
* "left": Use all the intersecting columns from the left dataframe
* "right": Use all the intersecting columns from the right dataframe
* [["x_in_left", "y_in_left"], ["z_in_right"]]: Provide column names for both
x1
x2
x3
x4
x5
x6
x7
0
A
J
734
499
595.0
187
431.0
1
B
J
357
202
525.0
523
503.5
2
C
H
864
568
433.5
988
181.05
3
D
J
530
703
112.3
419
42.0
4
E
H
61
521
906.0
805
558.2
5
F
H
482
496
13.0
722
721.0
[Keeping the duplicate columns from the right dataframe]
# Same x1 join, but the overlapping columns are taken from dataframe_2.
on_x1 = sh.JoinStatement("x1")
joined = sh.join(
    dataframe_1,
    dataframe_2,
    on_x1,
    overwrite_strategy="right",
)
joined.toPandas()
x1
x2
x3
x4
x5
x6
x7
0
A
J
O
946
595.0
187
431.0
1
B
J
P
692
525.0
523
503.5
2
C
H
Q
550
433.5
988
181.05
3
D
J
R
50
112.3
419
42.0
4
E
H
S
824
906.0
805
558.2
5
F
H
T
69
13.0
722
721.0
[Keeping the duplicate columns from both]
# Choose overlapping columns per side: x1 and x3 from dataframe_1, x4 from dataframe_2.
keep_from_left = ["x1", "x3"]
keep_from_right = ["x4"]
joined = sh.join(
    dataframe_1,
    dataframe_2,
    sh.JoinStatement("x1"),
    overwrite_strategy=[keep_from_left, keep_from_right],
)
joined.toPandas()
x1
x2
x3
x4
x5
x6
x7
0
A
J
734
946
595.0
187
431.0
1
B
J
357
692
525.0
523
503.5
2
C
H
864
550
433.5
988
181.05
3
D
J
530
50
112.3
419
42.0
4
E
H
61
824
906.0
805
558.2
5
F
H
482
69
13.0
722
721.0
# Combine two join conditions with "or", keeping overlapping columns from the left.
statement = sh.JoinStatement(
    sh.JoinStatement("x1"),        # left x1 == right x1
    sh.JoinStatement("x1", "x3"),  # left x1 == right x3
    "or",
)
joined = sh.join(dataframe_1, dataframe_2, statement, overwrite_strategy="left")
joined.toPandas()
x1
x2
x3
x4
x5
x6
x7
0
A
J
734
499
595.0
187
431.0
1
B
J
357
202
525.0
523
503.5
2
C
H
864
568
433.5
988
181.05
3
D
J
530
703
112.3
419
42.0
4
E
H
61
521
906.0
805
558.2
5
F
H
482
496
13.0
722
721.0
6
G
A
350
279
941.0
483
707.1
7
J
A
228
765
7.0
68
902.0
[Further nested joins are not supported]
(Perform sequential joins instead)
# A JoinStatement built from already-combined JoinStatements is rejected with
# NotImplementedError; chain several sh.join calls instead.
x1_x1 = sh.JoinStatement("x1")
x1_x3 = sh.JoinStatement("x1", "x3")  # left x1 == right x3
statement = sh.JoinStatement(x1_x1, x1_x3, "or")
statement_complex = sh.JoinStatement(statement, statement, "and")
try:
    joined = sh.join(dataframe_1, dataframe_2, statement_complex, overwrite_strategy="left")
except NotImplementedError as error:
    print(f"Error raised as expected: [{error}]")
else:
    # Fail loudly rather than silently skipping the demonstration.
    raise AssertionError("Expected NotImplementedError for a nested JoinStatement")
Error raised as expected: [Recursive JoinStatement not implemented]