Distributed DataFrame operations, transformations/actions, optimizations and Spark SQL patterns.
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import *
# Create session
spark = SparkSession.builder \
.appName('MyApp') \
.master('local[*]') \
.config('spark.sql.adaptive.enabled', 'true') \
.config('spark.driver.memory', '4g') \
.config('spark.sql.shuffle.partitions', '8') \
.getOrCreate()
# Stop when done
spark.stop()

| Config Key | Default | Purpose |
|---|---|---|
| spark.sql.adaptive.enabled | true (3.2+) | Auto-tune shuffle partitions |
| spark.sql.shuffle.partitions | 200 | Shuffle partition count |
| spark.driver.memory | 1g | Driver JVM memory |
| spark.executor.memory | 1g | Executor JVM memory |
| spark.executor.cores | 1 | CPU cores per executor |
| spark.sql.warehouse.dir | /user/hive/warehouse | Default warehouse path |
| spark.serializer | JavaSerializer | Set to KryoSerializer for faster serialization |
| spark.dynamicAllocation.enabled | false | Auto-scale executors |
| spark.sql.broadcastTimeout | 300s | Broadcast join timeout |
| spark.sql.autoBroadcastJoinThreshold | 10MB | Auto-broadcast threshold |
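Settings under spark.sql.* can also be read or changed at runtime via spark.conf; JVM settings like spark.driver.memory must be set before the session starts. A quick sketch:

# Inspect and tune an existing session (SQL configs only)
spark.conf.get('spark.sql.shuffle.partitions')         # returns '200' as a string
spark.conf.set('spark.sql.shuffle.partitions', '64')   # applies to subsequent queries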
# Install (pip / conda)
pip install pyspark
# PySpark with Jupyter
export PYSPARK_DRIVER_PYTHON=jupyter
export PYSPARK_DRIVER_PYTHON_OPTS='notebook'
pyspark --master local[*]
# With specific Spark version
pip install pyspark==3.5.0

Tip: Set spark.sql.shuffle.partitions to match your cluster size — the default of 200 creates too many small tasks for local development.

# Create RDDs
rdd1 = spark.sparkContext.parallelize([1, 2, 3, 4, 5])
rdd2 = spark.sparkContext.textFile('data.txt')
# From a DataFrame
rdd3 = df.rdd
# Basic actions
rdd.collect() # all elements (use carefully!)
rdd.first() # first element
rdd.take(5) # first N elements
rdd.count() # element count
rdd.reduce(lambda a, b: a + b)
rdd.sum()
rdd.mean()
rdd.stdev()
rdd.min() / rdd.max()

# Transformations (lazy — not executed until an action)
rdd.map(lambda x: x * 2)
rdd.flatMap(lambda x: x.split()) # 1-to-many
rdd.filter(lambda x: x > 10)
rdd.distinct()
rdd.sortBy(lambda x: x, ascending=False)
rdd.sample(withReplacement=False, fraction=0.1)
# Set operations
rdd1.union(rdd2)
rdd1.intersection(rdd2)
rdd1.subtract(rdd2)
# Key-value operations
kv_rdd = rdd.map(lambda x: (x, 1))
kv_rdd.reduceByKey(lambda a, b: a + b)
kv_rdd.groupByKey() # groups values into an iterable per key
kv_rdd.sortByKey()
kv_rdd.join(other_kv_rdd)
kv_rdd.leftOuterJoin(other_kv_rdd)
kv_rdd.mapValues(lambda v: v * 2)
kv_rdd.keys()
kv_rdd.values()

# Aggregate
zero_value = (0, 0) # (sum, count)
def seq_op(acc, val):
    return (acc[0] + val, acc[1] + 1)

def comb_op(acc1, acc2):
    return (acc1[0] + acc2[0], acc1[1] + acc2[1])
result = rdd.aggregate(zero_value, seq_op, comb_op)
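The result is a (sum, count) tuple, so the mean falls out on the driver:

total, count = result
mean = total / count   # 15 / 5 = 3.0 for parallelize([1, 2, 3, 4, 5])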
# Broadcast variables (efficient shared read-only data)
lookup = spark.sparkContext.broadcast({'a': 1, 'b': 2})
rdd.map(lambda x: lookup.value.get(x, 0))
# Accumulators (write-only shared counters)
counter = spark.sparkContext.accumulator(0)
rdd.foreach(lambda x: counter.add(1))
print(counter.value)
# Cache / persist
rdd.persist() # default: MEMORY_ONLY for RDDs
rdd.unpersist()
rdd.cache() # alias for persist(MEMORY_ONLY)

| Type | Description | Examples |
|---|---|---|
| Transformation | Lazy — builds DAG | map, filter, flatMap, join, groupByKey |
| Action | Triggers execution | collect, count, reduce, saveAsTextFile |
| Narrow | No data shuffle | map, filter, flatMap |
| Wide | Requires shuffle | groupByKey, reduceByKey, join, distinct |
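A classic word count ties these together — flatMap and map are narrow, reduceByKey is wide but combines locally before the shuffle. A minimal sketch, assuming a local words.txt:

lines = spark.sparkContext.textFile('words.txt')   # hypothetical input file
counts = (lines
    .flatMap(lambda line: line.split())    # narrow: line -> words
    .map(lambda word: (word, 1))           # narrow: word -> (word, 1)
    .reduceByKey(lambda a, b: a + b))      # wide: map-side combine, then shuffle
counts.take(10)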
Tip: Prefer reduceByKey over groupByKey — it performs local aggregation before shuffling, dramatically reducing network traffic.

from pyspark.sql import SparkSession
from pyspark.sql.types import *
# From list of tuples
data = [('Alice', 34, 70000), ('Bob', 45, 80000)]
df = spark.createDataFrame(data, ['name', 'age', 'salary'])
# With explicit schema
schema = StructType([
StructField('name', StringType(), True),
StructField('age', IntegerType(), True),
StructField('salary', DoubleType(), True),
])
df = spark.createDataFrame(data, schema)
# From pandas
import pandas as pd
pdf = pd.DataFrame({'name': ['Alice'], 'age': [34]})
df = spark.createDataFrame(pdf)
# From RDD with Row
from pyspark.sql import Row
rows = [Row(name='Alice', age=34), Row(name='Bob', age=45)]
df = spark.createDataFrame(rows)

# Read
df = spark.read.csv('data.csv', header=True, inferSchema=True)
df = spark.read.parquet('data.parquet')
df = spark.read.json('data.json')
df = spark.read.orc('data.orc')
df = spark.read.format('csv').option('header', 'true').load('data/')
# Read options
df = spark.read.csv('data.csv',
header=True,
inferSchema=True,
sep=',',
quote='"',
escape='\\',
nullValue='NA',
dateFormat='yyyy-MM-dd',
timestampFormat='yyyy-MM-dd HH:mm:ss',
multiLine=True,
)
# Write
df.write.csv('output/', header=True, mode='overwrite')
df.write.parquet('output/', mode='overwrite', partitionBy='year')
df.write.json('output/', mode='append')
# Write modes: 'overwrite', 'append', 'ignore', 'error' (default)

# Inspect
df.show(5, truncate=False) # show rows
df.printSchema() # schema tree
df.dtypes # column types list
df.columns # column names list
df.describe().show() # summary stats
df.summary().show() # extended stats
df.head(3) # first N as list of Rows
df.first() # first Row
df.count()
df.distinct().count()
# Select, rename
df.select('name', 'age')
df.select(df.name, (df.age + 1).alias('age_plus_1'))
df.select(F.col('name'), F.lit('constant'))
df.withColumnRenamed('name', 'full_name')
df.toDF('id', 'value') # rename all columns

Tip: Set inferSchema=True when reading CSVs — otherwise every column becomes a string and you lose type safety and optimization benefits.

from pyspark.sql import functions as F
# Filter / where (equivalent)
df.filter(F.col('age') > 30)
df.where((F.col('age') > 30) & (F.col('salary') > 60000))
df.filter(F.col('name').like('A%')) # SQL LIKE
df.filter(F.col('name').rlike('^[AB].*')) # regex
df.filter(F.col('name').isin('Alice', 'Bob'))
df.filter(F.col('age').isNull())
df.filter(F.col('age').isNotNull())
df.filter(F.col('name').startswith('A'))
df.filter(F.col('name').contains('li'))
# Sort
df.sort('age') # ascending
df.sort(F.col('age').desc())
df.orderBy(F.col('age').desc(), F.col('salary').asc())
df.orderBy(F.desc('age'), F.asc('salary'))

# Add / modify columns
from pyspark.sql import Window # needed for row_number().over() below
df.withColumn('bonus', F.col('salary') * 0.1)
df.withColumn('age_group', F.when(F.col('age') < 30, 'young')
.when(F.col('age') < 50, 'mid')
.otherwise('senior'))
df.withColumn('salary_k', F.round(F.col('salary') / 1000, 1))
df.withColumn('rank', F.row_number().over(Window.orderBy('salary')))
df.withColumn('salary', F.col('salary').cast('double'))
# Drop columns
df.drop('unnecessary_col')
df.drop(F.col('temp1'), F.col('temp2'))
# Handle nulls
df.fillna({'age': 0, 'salary': 50000})
df.na.fill(0, subset=['age'])
df.na.drop(how='any', thresh=3, subset=['name', 'age']) # thresh overrides how when both are given
df.na.replace('N/A', None)

from pyspark.sql import Window
# GroupBy aggregations
df.groupBy('department').agg(
F.count('*').alias('count'),
F.avg('salary').alias('avg_salary'),
F.sum('salary').alias('total_salary'),
F.min('age').alias('min_age'),
F.max('age').alias('max_age'),
F.collect_list('name').alias('employees'),
F.countDistinct('role').alias('unique_roles'),
).orderBy(F.col('avg_salary').desc())
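Conditional aggregation is a common variation — wrap F.when() inside an aggregate (rows that don't match produce nulls, which sum/avg ignore):

df.groupBy('department').agg(
    F.sum(F.when(F.col('age') > 40, F.col('salary'))).alias('senior_payroll'),
)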
# Window functions
w = Window.partitionBy('department').orderBy(F.desc('salary'))
df.withColumn('rank', F.rank().over(w))
df.withColumn('dense_rank', F.dense_rank().over(w))
df.withColumn('row_num', F.row_number().over(w))
df.withColumn('dept_avg', F.avg('salary').over(
Window.partitionBy('department')
))
# Joins
df1.join(df2, on='id', how='inner') # inner, left, right, full, cross
df1.join(df2, df1.id == df2.id) # different column names
df1.join(df2, ['id', 'dept']) # join on multiple columns
# Anti / semi join
df1.join(df2, on='id', how='left_anti') # NOT IN
df1.join(df2, on='id', how='left_semi') # EXISTS

# Pivot
df.groupBy('department').pivot('year', [2022, 2023, 2024]) \
.agg(F.sum('revenue'))
# Unpivot (melt) — Spark 3.4+
df.unpivot(
ids=['department'],
values=[F.col('2022'), F.col('2023')],
variableColumnName='year',
valueColumnName='revenue'
)
# Union
df1.union(df2) # all rows (may have duplicates)
df1.unionByName(df2) # match by column name
df1.distinct() # remove duplicates

| Function | Description | Notes |
|---|---|---|
| row_number() | Sequential number (1,2,3...) | No ties — unique per partition |
| rank() | Rank with gaps for ties | 1,2,2,4 |
| dense_rank() | Rank without gaps | 1,2,2,3 |
| percent_rank() | Rank as fraction | 0.0 to 1.0 |
| ntile(n) | Bucket into n groups | Quartiles: ntile(4) |
| lag(col, n) | Value n rows before | Access previous rows |
| lead(col, n) | Value n rows after | Access following rows |
| first() / last() | First / last in window | Use ignoreNulls=True |
| sum()/avg() over window | Running aggregation | Use rowsBetween for running totals |
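For example, lag() makes period-over-period deltas easy — a sketch assuming hypothetical store, date, and revenue columns:

w = Window.partitionBy('store').orderBy('date')
df.withColumn('prev_rev', F.lag('revenue', 1).over(w)) \
  .withColumn('delta', F.col('revenue') - F.col('prev_rev'))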
Tip: Use a left_anti join instead of NOT IN subqueries — it handles NULLs correctly and the optimizer can push down predicates.

# Register temp view
df.createOrReplaceTempView('employees')
df.createOrReplaceGlobalTempView('global_employees')
# Run SQL
result = spark.sql('''
SELECT department,
AVG(salary) as avg_salary,
COUNT(*) as headcount
FROM employees
WHERE age > 30
GROUP BY department
HAVING COUNT(*) > 5
ORDER BY avg_salary DESC
LIMIT 10
''')
result.show()
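Spark 3.4+ also supports named parameter markers, which avoid string interpolation in SQL:

# Parameterized SQL (Spark 3.4+)
spark.sql(
    'SELECT * FROM employees WHERE age > :min_age',
    args={'min_age': 30},
).show()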
# CTE
spark.sql('''
WITH dept_stats AS (
SELECT department, AVG(salary) as avg_sal
FROM employees GROUP BY department
)
SELECT e.name, e.salary, d.avg_sal,
e.salary - d.avg_sal as diff_from_avg
FROM employees e
JOIN dept_stats d ON e.department = d.department
''')

# Spark SQL has extensive built-in functions
spark.sql('''
SELECT
-- String functions
UPPER(name) as name_upper,
CONCAT(first_name, ' ', last_name) as full_name,
SUBSTRING(email, 1, POSITION('@' IN email) - 1) as user,
REGEXP_EXTRACT(phone, '\\d{3}', 0) as area_code,
TRIM(BOTH ' ' FROM name) as clean_name,
-- Date functions
CURRENT_DATE() as today,
CURRENT_TIMESTAMP() as now,
DATEDIFF(end_date, start_date) as days_diff,
DATE_ADD(start_date, 30) as plus_30d,
YEAR(order_date) as order_year,
-- Conditional
CASE WHEN salary > 100000 THEN 'high'
WHEN salary > 50000 THEN 'mid'
ELSE 'low' END as salary_band,
-- Aggregate
PERCENTILE(salary, 0.5) as median_sal,
COLLECT_LIST(DISTINCT role) as all_roles
FROM employees
GROUP BY name, first_name, last_name, email,
phone, salary, end_date, start_date, order_date
''')

# Catalog operations
spark.catalog.listDatabases()
spark.catalog.listTables()
spark.catalog.listColumns('employees')
# Check query plan
spark.sql('EXPLAIN SELECT * FROM employees').show(truncate=False)
spark.sql('EXPLAIN FORMATTED SELECT * FROM employees').show(truncate=False)
# Cache table
spark.sql('CACHE TABLE employees')
spark.sql('UNCACHE TABLE employees')
# Create table from query
spark.sql('''
CREATE TABLE IF NOT EXISTS high_earners
USING PARQUET
PARTITIONED BY (department)
AS SELECT * FROM employees WHERE salary > 100000
''')

Tip: Use EXPLAIN FORMATTED to read query plans — look for BroadcastHashJoin (good for small tables) vs SortMergeJoin (good for large tables).

# Read streaming source
stream_df = (spark.readStream
.format('kafka')
.option('kafka.bootstrap.servers', 'localhost:9092')
.option('subscribe', 'events')
.option('startingOffsets', 'latest')
.load()
)
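Kafka delivers key/value as binary, so the usual next step is a cast plus from_json — a sketch assuming the payload is JSON and event_schema is a StructType defined elsewhere:

parsed = (stream_df
    .selectExpr('CAST(value AS STRING) AS json')
    .select(F.from_json('json', event_schema).alias('e'))   # event_schema: assumed StructType
    .select('e.*')
)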
# From file source (CSV, JSON, Parquet)
stream_df = (spark.readStream
.schema(user_schema) # file sources require a predefined StructType
.format('csv')
.option('header', 'true')
.load('input/streaming/')
)
# From Delta Lake
stream_df = (spark.readStream
.format('delta')
.load('delta/events/')
)

# Write stream — complete output mode
query = (stream_df
.writeStream
.outputMode('complete') # 'append', 'update', 'complete'
.format('console')
.option('truncate', 'false')
.trigger(processingTime='10 seconds')
.start()
)
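The returned StreamingQuery handle can be monitored and stopped:

query.status          # current state of the query
query.lastProgress    # metrics for the most recent micro-batch
query.stop()          # stop the stream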
# Write to Delta Lake (append mode)
query = (stream_df
.groupBy('window', 'event_type')
.count()
.writeStream
.outputMode('update')
.format('delta')
.option('checkpointLocation', '/checkpoint/events')
.trigger(processingTime='1 minute')
.start('delta/events_agg/')
)
# ForeachBatch sink
def process_batch(batch_df, batch_id):
    # props: JDBC connection properties dict (user, password, driver), defined elsewhere
    batch_df.write.mode('append').jdbc(
        url='jdbc:postgresql://db:5432/analytics',
        table='events', properties=props
    )
query = (stream_df
.writeStream
.foreachBatch(process_batch)
.option('checkpointLocation', '/checkpoint/pg')
.start()
)
query.awaitTermination()

# Windowed aggregations
from pyspark.sql.functions import window
result = (stream_df
.withWatermark('event_time', '5 minutes') # late data tolerance
.groupBy(
window('event_time', '10 minutes', '5 minutes'), # window + slide
'category'
)
.agg(
F.count('*').alias('event_count'),
F.avg('value').alias('avg_value'),
)
)
result.writeStream \
.outputMode('update') \
.format('console') \
.start()

| Mode | Use Case | Supported Ops |
|---|---|---|
| append | New rows only, emitted once finalized | Map, filter; aggregations need a watermark |
| update | Only rows changed since last trigger | Aggregations |
| complete | Rewrite full result table each trigger | Aggregations only |
Tip: Always set checkpointLocation for production streaming jobs — it enables fault tolerance by saving the current offsets/state so the stream can resume after failure.

# Cache — keep in memory
df.cache() # = persist(MEMORY_AND_DISK)
df.persist() # default: MEMORY_AND_DISK
# Storage levels
from pyspark import StorageLevel
df.persist(StorageLevel.MEMORY_ONLY) # fastest, may drop
df.persist(StorageLevel.MEMORY_AND_DISK) # spill to disk
df.persist(StorageLevel.DISK_ONLY) # slow, durable
# Note: MEMORY_ONLY_SER / MEMORY_AND_DISK_SER exist only on the JVM side —
# PySpark always stores data serialized, so pyspark.StorageLevel omits the _SER levels
# Unpersist
df.unpersist()
# Check if cached
df.is_cached # True / False
df.storageLevel # current StorageLevel

# Repartition (full shuffle)
df.repartition(10) # hash-partition to 10
df.repartition('department') # by column
df.repartition(10, 'department') # 10 partitions by column
# Coalesce (reduce partitions, no full shuffle)
df.coalesce(5) # efficient for reducing
# Partition count
df.rdd.getNumPartitions()
# Write partitioned
df.write.partitionBy('year', 'month') \
.mode('overwrite').parquet('output/')
# Bucketing (pre-join optimization)
df.write.bucketBy(32, 'user_id') \
.sortBy('user_id') \
.mode('overwrite') \
.saveAsTable('bucketed_users')

# Adaptive Query Execution (AQE) — Spark 3.x
spark.conf.set('spark.sql.adaptive.enabled', 'true')
spark.conf.set('spark.sql.adaptive.coalescePartitions.enabled', 'true')
spark.conf.set('spark.sql.adaptive.skewJoin.enabled', 'true')
# Broadcast join hint (small table)
result = df1.join(
F.broadcast(df2), # broadcasts df2 to all executors
on='key'
)
# Repartition hint
result = df1.hint('repartition', 10).join(df2, on='key')
# Sort-merge join hint (there is no 'CACHE' hint — call df.cache() directly)
result = df1.hint('merge').join(df2, on='key')

| Strategy | Condition | Best For |
|---|---|---|
| BroadcastHashJoin | One side < broadcast threshold | Small-dimension + large-fact |
| SortMergeJoin | Both sides large, sorted | Large-large joins |
| ShuffleHashJoin | One side fits in memory | Medium-large joins |
| BroadcastNestedLoopJoin | Fallback, cross join | Very small tables, no equi-join |
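To confirm which strategy the optimizer actually picked, inspect the physical plan — a quick check using df1/df2 from above:

df1.join(F.broadcast(df2), on='key').explain()
# look for BroadcastHashJoin (vs SortMergeJoin) in the physical plan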
Tip: Use coalesce(1) when writing a single output file (e.g., a CSV export), but be careful — all data goes to one partition, which can cause OOM on large datasets.

# spark-submit
spark-submit \
--master yarn \
--deploy-mode cluster \
--num-executors 10 \
--executor-cores 4 \
--executor-memory 8g \
--driver-memory 4g \
--conf spark.sql.adaptive.enabled=true \
--conf spark.sql.shuffle.partitions=80 \
--py-files deps.zip \
--files config.properties \
my_pipeline.py

# User-Defined Functions
from pyspark.sql.functions import udf
from pyspark.sql.types import DoubleType
# Classic UDF (slower — Python serialization overhead)
@udf(returnType=DoubleType())
def celsius_to_fahrenheit(c):
    return c * 9.0 / 5.0 + 32.0

df = df.withColumn('temp_f', celsius_to_fahrenheit('temp_c'))
# Pandas UDF (much faster — vectorized with Apache Arrow)
import pandas as pd
from pyspark.sql.functions import pandas_udf
@pandas_udf(DoubleType())
def vectorized_convert(s: pd.Series) -> pd.Series:
    return s * 9.0 / 5.0 + 32.0

df = df.withColumn('temp_f', vectorized_convert('temp_c'))
# Grouped map — use applyInPandas (PandasUDFType.GROUPED_MAP is deprecated since 3.0)
def normalize(pdf: pd.DataFrame) -> pd.DataFrame:
    pdf['value'] = (pdf['value'] - pdf['value'].mean()) / pdf['value'].std()
    return pdf

df.groupBy('category').applyInPandas(normalize, schema=df.schema)

# Testing with local SparkSession
import pytest
from pyspark.sql import SparkSession
@pytest.fixture(scope='session')
def spark():
    session = SparkSession.builder \
        .master('local[1]') \
        .appName('test') \
        .config('spark.sql.shuffle.partitions', '2') \
        .getOrCreate()
    yield session
    session.stop()
def test_pipeline(spark):
    # Create test data
    test_df = spark.createDataFrame(
        [(1, 'Alice', 100), (2, 'Bob', 200)],
        ['id', 'name', 'amount']
    )
    # Run transformation
    result = my_transformation(test_df)
    # Assert
    assert result.count() == 2
    assert result.filter(F.col('amount') > 50).count() == 2

| Area | Best Practice |
|---|---|
| Partitioning | Set shuffle.partitions ≈ total cores × 2-3 |
| Memory | Leave 20% headroom; monitor with Spark UI |
| Serialization | Use KryoSerializer; register custom classes |
| UDFs | Prefer Pandas UDFs over Python UDFs (10-100x faster) |
| Files | Use Parquet or Delta Lake (columnar, compressed) |
| Joins | Broadcast small tables; bucket large joined tables |
| Skew | Enable AQE skew join; salt keys for known skew |
| Checkpoint | Always set for streaming; externalize for long jobs |
| Testing | Test with local[1]; use chispa for DataFrame assertions |
| Logging | Configure log4j; use sparkContext.setLogLevel('WARN') in production |
Tip: Avoid collect() on large DataFrames in production — it brings all data to the driver. Use take(), show(), or write to storage instead.
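A sketch of the safer alternatives:

df.take(100)              # bounded sample to the driver
df.toLocalIterator()      # iterate partition by partition, bounded driver memory
df.write.parquet('out/')  # keep large results in storage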