try:
    from pyspark import SparkConf
except ImportError:
    !pip install pyspark==3.2.1
    from pyspark import SparkConf
from pyspark.sql import SparkSession, types as st
from IPython.display import HTML
import spark.helpers as sh
# Setup Spark
conf = SparkConf().setMaster("local[1]").setAppName("examples")
spark = SparkSession.builder.config(conf=conf).getOrCreate()
spark.sparkContext.setLogLevel('ERROR')
# Load example datasets
dataframe_1 = spark.read.options(header=True).csv("./data/dataset_1.csv")
dataframe_2 = spark.read.options(header=True).csv("./data/dataset_2.csv")
html = (
    "<div style='float:left'><h4>Dataset 1:</h4>" +
    dataframe_1.toPandas().to_html() +
    "</div><div style='float:left; margin-left:50px;'><h4>Dataset 2:</h4>" +
    dataframe_2.toPandas().to_html() +
    "</div>"
)
HTML(html)
Dataset 1:

|   | x1 | x2 | x3 | x4 | x5 |
|---|----|----|-----|-----|-------|
| 0 | A | J | 734 | 499 | 595.0 |
| 1 | B | J | 357 | 202 | 525.0 |
| 2 | C | H | 864 | 568 | 433.5 |
| 3 | D | J | 530 | 703 | 112.3 |
| 4 | E | H | 61 | 521 | 906.0 |
| 5 | F | H | 482 | 496 | 13.0 |
| 6 | G | A | 350 | 279 | 941.0 |
| 7 | H | C | 171 | 267 | 423.0 |
| 8 | I | C | 755 | 133 | 600.0 |
| 9 | J | A | 228 | 765 | 7.0 |
Dataset 2:

|   | x1 | x3 | x4 | x6 | x7 |
|---|----|----|-----|-----|-------|
| 0 | W | K | 391 | 140 | 872.0 |
| 1 | X | G | 88 | 483 | 707.1 |
| 2 | Y | M | 144 | 476 | 714.3 |
| 3 | Z | J | 896 | 68 | 902.0 |
| 4 | A | O | 946 | 187 | 431.0 |
| 5 | B | P | 692 | 523 | 503.5 |
| 6 | C | Q | 550 | 988 | 181.05 |
| 7 | D | R | 50 | 419 | 42.0 |
| 8 | E | S | 824 | 805 | 558.2 |
| 9 | F | T | 69 | 722 | 721.0 |
for group, data in sh.group_iterator(dataframe_1, "x2"):
    print(group, " => ", data.toPandas().shape[0])
A => 2
C => 2
H => 3
J => 3
for group, data in sh.group_iterator(dataframe_1, ["x1", "x2"]):
    print(group, " => ", data.toPandas().shape[0])
('A', 'J') => 1
('B', 'J') => 1
('C', 'H') => 1
('D', 'J') => 1
('E', 'H') => 1
('F', 'H') => 1
('G', 'A') => 1
('H', 'C') => 1
('I', 'C') => 1
('J', 'A') => 1
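The helper hides the boilerplate of collecting the distinct keys and filtering per key. Below is a minimal sketch of that idea in plain PySpark (not the actual `group_iterator` implementation, and assuming the distinct key combinations fit in driver memory):

```python
from pyspark.sql import functions as sf

def iterate_groups(dataframe, columns):
    """Yield (key, sub-dataframe) pairs, one per distinct key combination."""
    if isinstance(columns, str):
        columns = [columns]
    # Collect the distinct key combinations to the driver (assumed to be small).
    for row in dataframe.select(*columns).distinct().collect():
        key = tuple(row[column] for column in columns)
        condition = sf.lit(True)
        for column, value in zip(columns, key):
            condition = condition & (sf.col(column) == value)
        yield (key[0] if len(key) == 1 else key), dataframe.filter(condition)
```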
before = [(x["name"], x["type"]) for x in dataframe_1.schema.jsonValue()["fields"]]
schema = {
    "x2": st.IntegerType(),
    "x5": st.FloatType(),
}
new_dataframe = sh.change_schema(dataframe_1, schema)
after = [(x["name"], x["type"]) for x in new_dataframe.schema.jsonValue()["fields"]]
check = [
    ('x1', 'string'),
    ('x2', 'integer'),
    ('x3', 'string'),
    ('x4', 'string'),
    ('x5', 'float'),
]
assert before != after
assert after == check
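For comparison, the same conversion can be written with the built-in `cast` API; the helper mostly saves the loop below. This is only a sketch, reusing the `schema` mapping and `check` list defined above:

```python
from pyspark.sql import functions as sf

# Cast each column listed in `schema` to its target DataType.
casted = dataframe_1
for column, data_type in schema.items():
    casted = casted.withColumn(column, sf.col(column).cast(data_type))

assert [(f["name"], f["type"]) for f in casted.schema.jsonValue()["fields"]] == check
```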
joined = sh.join(dataframe_1.select("x2", "x5"), dataframe_2, sh.JoinStatement("x2", "x1"))
joined.toPandas()
|   | x1 | x2 | x3 | x4 | x5 | x6 | x7 |
|---|----|----|----|-----|-------|-----|--------|
| 0 | A | A | O | 946 | 7.0 | 187 | 431.0 |
| 1 | A | A | O | 946 | 941.0 | 187 | 431.0 |
| 2 | C | C | Q | 550 | 600.0 | 988 | 181.05 |
| 3 | C | C | Q | 550 | 423.0 | 988 | 181.05 |
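For reference, roughly the same join spelled out against the plain DataFrame API: `JoinStatement("x2", "x1")` corresponds to an equality between the left `x2` and the right `x1` column (an inner join is assumed here, matching the output above):

```python
# Manual equivalent of the helper call above (sketch only).
left = dataframe_1.select("x2", "x5")
condition = left["x2"] == dataframe_2["x1"]
manual = left.join(dataframe_2, condition, "inner")
manual.toPandas()
```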
try:
    joined = sh.join(dataframe_1, dataframe_2, sh.JoinStatement("x1"))
except ValueError as error:
    print(f"Error raised as expected: {error}")
joined = sh.join(dataframe_1, dataframe_2, sh.JoinStatement("x1"), overwrite_strategy="left")
joined.toPandas()
Error raised as expected:
Overlapping columns found in the dataframes: ['x1', 'x3', 'x4']
Please provide the `overwrite_strategy` argument to select a resolution strategy:
  * "left": Use all the intersecting columns from the left dataframe
  * "right": Use all the intersecting columns from the right dataframe
  * [["x_in_left", "y_in_left"], ["z_in_right"]]: Provide explicit column names for both dataframes
|   | x1 | x2 | x3 | x4 | x5 | x6 | x7 |
|---|----|----|-----|-----|-------|-----|--------|
| 0 | A | J | 734 | 499 | 595.0 | 187 | 431.0 |
| 1 | B | J | 357 | 202 | 525.0 | 523 | 503.5 |
| 2 | C | H | 864 | 568 | 433.5 | 988 | 181.05 |
| 3 | D | J | 530 | 703 | 112.3 | 419 | 42.0 |
| 4 | E | H | 61 | 521 | 906.0 | 805 | 558.2 |
| 5 | F | H | 482 | 496 | 13.0 | 722 | 721.0 |
joined = sh.join(dataframe_1, dataframe_2, sh.JoinStatement("x1"), overwrite_strategy="right")
joined.toPandas()
|   | x1 | x2 | x3 | x4 | x5 | x6 | x7 |
|---|----|----|----|-----|-------|-----|--------|
| 0 | A | J | O | 946 | 595.0 | 187 | 431.0 |
| 1 | B | J | P | 692 | 525.0 | 523 | 503.5 |
| 2 | C | H | Q | 550 | 433.5 | 988 | 181.05 |
| 3 | D | J | R | 50 | 112.3 | 419 | 42.0 |
| 4 | E | H | S | 824 | 906.0 | 805 | 558.2 |
| 5 | F | H | T | 69 | 13.0 | 722 | 721.0 |
joined = sh.join(
    dataframe_1, dataframe_2, sh.JoinStatement("x1"),
    overwrite_strategy=[["x1", "x3"], ["x4"]]
)
joined.toPandas()
|   | x1 | x2 | x3 | x4 | x5 | x6 | x7 |
|---|----|----|-----|-----|-------|-----|--------|
| 0 | A | J | 734 | 946 | 595.0 | 187 | 431.0 |
| 1 | B | J | 357 | 692 | 525.0 | 523 | 503.5 |
| 2 | C | H | 864 | 550 | 433.5 | 988 | 181.05 |
| 3 | D | J | 530 | 50 | 112.3 | 419 | 42.0 |
| 4 | E | H | 61 | 824 | 906.0 | 805 | 558.2 |
| 5 | F | H | 482 | 69 | 13.0 | 722 | 721.0 |
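In plain PySpark, the overwrite strategies amount to choosing which copy of each overlapping column survives the join. A rough sketch of the "left" strategy, assuming an inner join and the overlapping columns reported above:

```python
# Manual equivalent of overwrite_strategy="left" (sketch only).
condition = dataframe_1["x1"] == dataframe_2["x1"]
manual = dataframe_1.join(dataframe_2, condition, "inner")
for column in ["x1", "x3", "x4"]:
    # Keep the left-hand copy by dropping the right-hand one.
    manual = manual.drop(dataframe_2[column])
```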
x1_x1 = sh.JoinStatement("x1")
x1_x3 = sh.JoinStatement("x1", "x3")
statement = sh.JoinStatement(x1_x1, x1_x3, "or")
joined = sh.join(dataframe_1, dataframe_2, statement, overwrite_strategy="left")
joined.toPandas()
|   | x1 | x2 | x3 | x4 | x5 | x6 | x7 |
|---|----|----|-----|-----|-------|-----|--------|
| 0 | A | J | 734 | 499 | 595.0 | 187 | 431.0 |
| 1 | B | J | 357 | 202 | 525.0 | 523 | 503.5 |
| 2 | C | H | 864 | 568 | 433.5 | 988 | 181.05 |
| 3 | D | J | 530 | 703 | 112.3 | 419 | 42.0 |
| 4 | E | H | 61 | 521 | 906.0 | 805 | 558.2 |
| 5 | F | H | 482 | 496 | 13.0 | 722 | 721.0 |
| 6 | G | A | 350 | 279 | 941.0 | 483 | 707.1 |
| 7 | J | A | 228 | 765 | 7.0 | 68 | 902.0 |
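Composed statements map onto ordinary boolean Column expressions. The "or" statement above is roughly equivalent to the sketch below, with the overlapping columns then resolved as shown earlier:

```python
# Manual equivalent of the composed "or" statement (sketch only).
condition = (
    (dataframe_1["x1"] == dataframe_2["x1"]) |
    (dataframe_1["x1"] == dataframe_2["x3"])
)
manual = dataframe_1.join(dataframe_2, condition, "inner")
```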
Nested (recursive) `JoinStatement`s are not supported; perform sequential joins instead:
x1_x1 = sh.JoinStatement("x1")
x1_x3 = sh.JoinStatement("x1", "x3")
statement = sh.JoinStatement(x1_x1, x1_x3, "or")
statement_complex = sh.JoinStatement(statement, statement, "and")
try:
    joined = sh.join(dataframe_1, dataframe_2, statement_complex, overwrite_strategy="left")
except NotImplementedError as error:
    print(f"Error raised as expected: [{error}]")
Error raised as expected: [Recursive JoinStatement not implemented]
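One way to read the note above: apply one statement per `sh.join` call and chain the results. A sketch of that workaround (the third dataframe and its keys are hypothetical, not part of the example data):

```python
# First join: apply the composed "or" statement as before.
intermediate = sh.join(dataframe_1, dataframe_2, statement, overwrite_strategy="left")
# Second join: apply the next statement to the intermediate result.
# `dataframe_3` and the "x2"/"x6" keys are placeholders; substitute your own.
# final = sh.join(intermediate, dataframe_3, sh.JoinStatement("x2", "x6"), overwrite_strategy="left")
```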