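Summary
Apache Arrow is an in-memory columnar data format that Spark uses to transfer data efficiently between JVM and Python processes.
Using Arrow
Arrow-based data transfer must be enabled before it can be used.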
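import numpy as np
import pandas as pd
from pyspark.sql import SparkSession
# Reuse the active session (for example in the pyspark shell) or create one
spark = SparkSession.builder.getOrCreate()
# Enable Arrow-based columnar data transfer
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
# Generate a pandas DataFrame
pdf = pd.DataFrame(np.random.rand(100, 3))
# Create a Spark DataFrame from the pandas DataFrame using Arrow
df = spark.createDataFrame(pdf)
# Convert the Spark DataFrame back to a pandas DataFrame using Arrow
result_pdf = df.select("*").toPandas()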
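Pandas UDFs (vectorized UDFs)
Pandas UDFs use Arrow to transfer data and pandas to process it, which enables vectorized operations. Wrapping a user-defined function with the pandas_udf decorator turns it into a pandas UDF. Since Spark 3.0, defining pandas UDFs with Python type hints is the recommended approach.
In general, the input and output types should be pandas.Series. The one exception is when an input or output column is a StructType, in which case the corresponding type is pandas.DataFrame.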
import pandas as pd
from pyspark.sql.functions import pandas_udf
@pandas_udf("col1 string, col2 long")
def func(s1: pd.Series, s2: pd.Series, s3: pd.DataFrame) -> pd.DataFrame:
    # s3 holds the struct column as a DataFrame; add a second field to it
    s3['col2'] = s1 + s2.str.len()
    return s3
# Create a Spark DataFrame that has three columns including a struct column.
df = spark.createDataFrame(
    [[1, "a string", ("a nested string",)]],
    "long_col long, string_col string, struct_col struct<col1:string>")
df.printSchema()
# root
# |-- long_col: long (nullable = true)
# |-- string_col: string (nullable = true)
# |-- struct_col: struct (nullable = true)
# |    |-- col1: string (nullable = true)
df.select(func("long_col", "string_col", "struct_col")).printSchema()
# |-- func(long_col, string_col, struct_col): struct (nullable = true)
# |    |-- col1: string (nullable = true)
# |    |-- col2: long (nullable = true)
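Because a struct column reaches the function as a pandas.DataFrame, the function body can be checked locally with plain pandas before it is wrapped in pandas_udf. The following is a minimal sketch under that assumption; add_length_column and the sample values are illustrative names chosen here, not part of the Spark API.

```python
import pandas as pd

def add_length_column(s1: pd.Series, s2: pd.Series, s3: pd.DataFrame) -> pd.DataFrame:
    # Same logic as the UDF body, but working on a copy so the input is not mutated
    out = s3.copy()
    out["col2"] = s1 + s2.str.len()
    return out

# Local stand-ins for one batch of each column
long_col = pd.Series([1])
string_col = pd.Series(["a string"])
struct_col = pd.DataFrame({"col1": ["a nested string"]})  # struct column -> DataFrame

result = add_length_column(long_col, string_col, struct_col)
print(result)
```

The returned DataFrame has one pandas column per field of the declared struct return type ("col1 string, col2 long").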
Series to Series
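Internally, Spark executes a pandas UDF by splitting the columns into batches, calling the function on each batch, and concatenating the results.
The following pandas UDF multiplies two columns.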
import pandas as pd
from pyspark.sql.functions import col, pandas_udf
from pyspark.sql.types import LongType
def multiply_func(a: pd.Series, b: pd.Series) -> pd.Series:
    return a * b
multiply = pandas_udf(multiply_func, returnType=LongType())
# The function for a pandas_udf should be able to execute with local pandas data
x = pd.Series([1, 2, 3])
print(multiply_func(x, x))
# 0    1
# 1    4
# 2    9
# dtype: int64
df = spark.createDataFrame(pd.DataFrame(x, columns=["x"]))
df.select(multiply(col("x"), col("x"))).show()
# +-------------------+
# |multiply_func(x, x)|
# +-------------------+
# |                  1|
# |                  4|
# |                  9|
# +-------------------+
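The batch-and-concatenate behavior described for Series to Series UDFs can be imitated with plain pandas. The sketch below is a local simulation only, not Spark's implementation: apply_in_batches and the batch size of 3 are hypothetical and chosen for illustration (in Spark the batch size is governed by spark.sql.execution.arrow.maxRecordsPerBatch).

```python
import pandas as pd

def multiply_func(a: pd.Series, b: pd.Series) -> pd.Series:
    return a * b

def apply_in_batches(func, a: pd.Series, b: pd.Series, batch_size: int) -> pd.Series:
    # Hypothetical helper: slice both columns into aligned batches,
    # call the function once per batch, then concatenate the pieces.
    pieces = []
    for start in range(0, len(a), batch_size):
        stop = start + batch_size
        pieces.append(func(a.iloc[start:stop], b.iloc[start:stop]))
    return pd.concat(pieces)

x = pd.Series(range(1, 8))
batched = apply_in_batches(multiply_func, x, x, batch_size=3)
whole = multiply_func(x, x)
print(batched.equals(whole))  # True
```

The two results match because multiply_func is element-wise. A function whose output depends on the whole column, such as subtracting the column mean, would give different answers per batch and is therefore not a good fit for this UDF type.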
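Iterator of Series to Iterator of Series
This variant takes an iterator of batches and yields an iterator of batches. It is useful when the UDF needs some expensive one-time initialization, as in the following pattern.
from typing import Iterator
@pandas_udf("long")
def calculate(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
    # Do some expensive initialization with a state.
    # very_expensive_initialization and calculate_with_state are placeholders for user code.
    state = very_expensive_initialization()
    for x in iterator:
        # Use that state for the whole iterator.
        yield calculate_with_state(x, state)
df.select(calculate("value")).show()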
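The initialize-once pattern can be tried locally by feeding the undecorated generator function an ordinary Python iterator of pandas.Series batches. This is a sketch under that assumption: load_offset, add_offset and init_calls are hypothetical names standing in for the expensive setup and the per-batch computation.

```python
from typing import Iterator

import pandas as pd

init_calls = 0  # counts how often the setup step runs

def load_offset() -> int:
    """Hypothetical stand-in for an expensive one-time setup step."""
    global init_calls
    init_calls += 1
    return 100

def add_offset(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
    offset = load_offset()  # runs once per iterator, not once per batch
    for batch in iterator:
        yield batch + offset

# Three batches of different sizes, as a UDF might receive them
batches = [pd.Series([1, 2]), pd.Series([3]), pd.Series([4, 5])]
outputs = list(add_offset(iter(batches)))

print(init_calls)                     # 1
print([o.tolist() for o in outputs])  # [[101, 102], [103], [104, 105]]
```

Decorating add_offset with pandas_udf("long") would give the same per-batch logic inside Spark; the state is then set up once for each iterator of batches the function receives rather than once per batch.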
from typing import Iterator
import pandas as pd
from pyspark.sql.functions import pandas_udf
pdf = pd.DataFrame([1, 2, 3], columns=["x"])
df = spark.createDataFrame(pdf)
# Declare the function and create the UDF
@pandas_udf("long")
def plus_one(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
    for x in iterator:
        yield x + 1
df.select(plus_one("x")).show()
# +-----------+
# |plus_one(x)|
# +-----------+
# |          2|
# |          3|
# |          4|
# +-----------+
Iterator of Multiple Series to Iterator of Series
from typing import Iterator, Tuple
import pandas as pd
from pyspark.sql.functions import pandas_udf
pdf = pd.DataFrame([1, 2, 3], columns=["x"])
df = spark.createDataFrame(pdf)
# Declare the function and create the UDF
@pandas_udf("long")
def multiply_two_cols(
        iterator: Iterator[Tuple[pd.Series, pd.Series]]) -> Iterator[pd.Series]:
    for a, b in iterator:
        yield a * b
df.select(multiply_two_cols("x", "x")).show()
# +-----------------------+
# |multiply_two_cols(x, x)|
# +-----------------------+
# |                      1|
# |                      4|
# |                      9|
# +-----------------------+
Series to Scalar
import pandas as pd
from pyspark.sql.functions import pandas_udf
from pyspark.sql import Window
df = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
    ("id", "v"))
# Declare the function and create the UDF
@pandas_udf("double")
def mean_udf(v: pd.Series) -> float:
    return v.mean()
df.select(mean_udf(df['v'])).show()
# +-----------+
# |mean_udf(v)|
# +-----------+
# |        4.2|
# +-----------+
df.groupby("id").agg(mean_udf(df['v'])).show()
# +---+-----------+
# | id|mean_udf(v)|
# +---+-----------+
# |  1|        1.5|
# |  2|        6.0|
# +---+-----------+
w = Window \
    .partitionBy('id') \
    .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
df.withColumn('mean_v', mean_udf(df['v']).over(w)).show()
# +---+----+------+
# | id|   v|mean_v|
# +---+----+------+
# |  1| 1.0|   1.5|
# |  1| 2.0|   1.5|
# |  2| 3.0|   6.0|
# |  2| 5.0|   6.0|
# |  2|10.0|   6.0|
# +---+----+------+
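The three uses of the Series to scalar UDF (whole column, grouped aggregation, unbounded window) have close pandas counterparts, which can help when checking expected values by hand. The following is a local sketch with the same sample data, not Spark code; mean_func, overall, per_group and windowed are names chosen here for illustration.

```python
import pandas as pd

pdf = pd.DataFrame({"id": [1, 1, 2, 2, 2], "v": [1.0, 2.0, 3.0, 5.0, 10.0]})

def mean_func(v: pd.Series) -> float:
    # Same reduction as the UDF body: many rows in, one scalar out
    return float(v.mean())

# Whole column, comparable to df.select(mean_udf(df['v']))
overall = mean_func(pdf["v"])

# One value per group, comparable to df.groupby("id").agg(mean_udf(df['v']))
per_group = pdf.groupby("id")["v"].agg(mean_func)

# Group value repeated on every row, comparable to the unbounded window over id
windowed = pdf.assign(mean_v=pdf.groupby("id")["v"].transform("mean"))

print(overall)
print(per_group)
print(windowed)
```

The grouped result has one row per id, while the windowed result keeps all five rows and repeats each group's mean, which matches the shapes of the Spark outputs for this sample data.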
Using pandas UDFs / Apache Arrow in PySpark
Original post: https://www.cnblogs.com/zcsh/p/14840836.html