Skip to content

Spark SQL

Spark SQL 提供了用于处理结构化数据的高级 API。

概述

Spark SQL 支持使用 SQL 查询和 DataFrame API 来处理结构化数据。

DataFrame 操作

创建 DataFrame

python
from pyspark.sql import SparkSession

spark = SparkSession.builder \
    .appName("SparkSQLExample") \
    .getOrCreate()

# 从 CSV 文件创建
df = spark.read.csv("data.csv", header=True, inferSchema=True)

# 从 JSON 文件创建
df = spark.read.json("data.json")

# 从列表创建
data = [("Alice", 25), ("Bob", 30)]
df = spark.createDataFrame(data, ["name", "age"])

数据查看

python
# 显示前几行
df.show()

# 查看数据信息
df.printSchema()

# 查看统计摘要
df.describe().show()

数据查询

python
# 选择列
df.select("name", "age").show()

# 过滤数据
df.filter(df["age"] > 18).show()

# 排序
df.orderBy(df["age"].desc()).show()

# 分组聚合
df.groupBy("gender").count().show()

SQL 查询

python
# 创建临时视图
df.createOrReplaceTempView("users")

# 执行 SQL
result = spark.sql("SELECT name, age FROM users WHERE age > 18")
result.show()

Dataset API

python
from pyspark.sql import Row

# 创建 Dataset
ds = spark.createDataset([Row(name="Alice", age=25), Row(name="Bob", age=30)])

# 使用 lambda 表达式
ds.filter(lambda x: x.age > 18).show()

# 使用类型安全的操作
ds.select(ds.name, ds.age + 1).show()

性能优化

缓存 DataFrame

python
df.cache()
df.persist()

广播变量

python
from pyspark.sql.functions import broadcast

# 广播小表
small_df = spark.read.csv("small.csv")
large_df = spark.read.csv("large.csv")

# 使用广播连接
result = large_df.join(broadcast(small_df), "key")

分区优化

python
# 按列分区
df.write.partitionBy("year", "month").parquet("output")

# 读取分区数据
df = spark.read.parquet("output").where("year = 2024")

UDF

创建 UDF

python
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

# 定义函数
def greet(name):
    return f"Hello, {name}!"

# 注册 UDF
greet_udf = udf(greet, StringType())

# 使用 UDF
df.withColumn("greeting", greet_udf(df["name"])).show()

矢量化 UDF

python
from pyspark.sql.functions import pandas_udf
import pandas as pd

@pandas_udf(StringType())
def greet_pandas(names: pd.Series) -> pd.Series:
    return names.apply(lambda x: f"Hello, {x}!")

df.withColumn("greeting", greet_pandas(df["name"])).show()