Map-Side Join in Spark

Joining two or more data sets is one of the most common operations you perform on your data, but in distributed systems it can be a huge headache. Since the data are spread across many nodes, they generally have to be shuffled before the join, which causes significant network I/O and slow performance.

Fortunately, if you need to join a large table (fact) with relatively small tables (dimensions), i.e., to perform a star-schema join, you can avoid sending all the data of the large table over the network. This type of join is called a map-side join in the Hadoop community. In other distributed systems, it is often called a replicated or broadcast join.

Let’s use the following sample data (one fact and two dimension tables):

// Fact table
val flights = sc.parallelize(List(
  ("SEA", "JFK", "DL", "418",  "7:00"),
  ("SFO", "LAX", "AA", "1250", "7:05"),
  ("SFO", "JFK", "VX", "12",   "7:05"),
  ("JFK", "LAX", "DL", "424",  "7:10"),
  ("LAX", "SEA", "DL", "5737", "7:10")))  
  
// Dimension table
val airports = sc.parallelize(List(
  ("JFK", "John F. Kennedy International Airport", "New York", "NY"),
  ("LAX", "Los Angeles International Airport", "Los Angeles", "CA"),
  ("SEA", "Seattle-Tacoma International Airport", "Seattle", "WA"),
  ("SFO", "San Francisco International Airport", "San Francisco", "CA")))
  
// Dimension table
val airlines = sc.parallelize(List(
  ("AA", "American Airlines"), 
  ("DL", "Delta Airlines"), 
  ("VX", "Virgin America")))   

We need to join the fact and dimension tables to get the following result:

Seattle           New York       Delta Airlines       418   7:00
San Francisco     Los Angeles    American Airlines    1250  7:05
San Francisco     New York       Virgin America       12    7:05
New York          Los Angeles    Delta Airlines       424   7:10
Los Angeles       Seattle        Delta Airlines       5737  7:10
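
For comparison, a regular RDD join would key both data sets and repartition them by key, so every row of the large flights table would travel over the network. Here is a minimal sketch of such a shuffle join on the origin airport alone (the names flightsByOrigin, airportCities and shuffleJoined are just illustrative):

// Shuffle join: both RDDs are hashed and repartitioned by key,
// so the large fact table crosses the network
val flightsByOrigin = flights.map{case(org, dst, al, num, dep) => (org, (dst, al, num, dep))}
val airportCities = airports.map{case(code, _, city, _) => (code, city)}
val shuffleJoined = flightsByOrigin.join(airportCities)

The map-side join avoids exactly this repartitioning of the fact table.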

The fact table can be very large, while dimension tables are often quite small. Let's collect the dimension tables to the Spark driver, build maps from them, and broadcast those maps to each worker node:

// Keep only the airport code and city from airports, then broadcast both lookup maps
val airportsMap = sc.broadcast(airports.map{case(a, b, c, d) => (a, c)}.collectAsMap)
val airlinesMap = sc.broadcast(airlines.collectAsMap)

Now you can run the map-side join:

// Replace airport and airline codes with names using the broadcast maps
flights.map{case(a, b, c, d, e) =>
   (airportsMap.value.get(a).get,
    airportsMap.value.get(b).get,
    airlinesMap.value.get(c).get, d, e)}.collect

The result of the execution (formatted):

res: Array[(String, String, String, String, String)] = Array(
  (Seattle, New York, Delta Airlines, 418, 7:00), 
  (San Francisco, Los Angeles, American Airlines, 1250, 7:05), 
  (San Francisco, New York, Virgin America, 12, 7:05), 
  (New York, Los Angeles, Delta Airlines, 424, 7:10), 
  (Los Angeles, Seattle, Delta Airlines, 5737, 7:10))
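
One caveat: Map.get(key).get throws a NoSuchElementException when a code is missing from a dimension map. If the fact table may contain codes that the dimension tables do not, a more defensive variant (a sketch that falls back to the raw code when a lookup fails) could look like this:

flights.map{case(a, b, c, d, e) =>
   (airportsMap.value.getOrElse(a, a),
    airportsMap.value.getOrElse(b, b),
    airlinesMap.value.getOrElse(c, c), d, e)}.collect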

How it Works

First, we created an RDD for each table. airports and airlines are the dimension tables that we are going to use in the map-side join, so we converted them to maps and broadcast them to each executor node. Note that we extracted only 2 of the 4 columns from the airports table.

Then we simply applied the map function to each row of the flights table and retrieved the dimension values from airportsMap and airlinesMap. If the flights table is very large, the map function is executed concurrently for each partition, and every executor works with its own local copy of the airportsMap and airlinesMap maps.
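
To make the per-partition nature of the computation explicit, the same join can be written with mapPartitions so that the broadcast values are read once per partition instead of once per record. A minimal equivalent sketch (the local names airportsLocal and airlinesLocal are illustrative):

flights.mapPartitions{iter =>
  // Dereference the broadcast variables once for the whole partition
  val airportsLocal = airportsMap.value
  val airlinesLocal = airlinesMap.value
  iter.map{case(a, b, c, d, e) =>
    (airportsLocal(a), airportsLocal(b), airlinesLocal(c), d, e)}
}.collect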

This approach lets us avoid shuffling the fact table entirely and gives quite good join performance.
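
For completeness: Spark SQL offers the same optimization through its broadcast hint. The following is only a sketch, assuming Spark 2.x where the shell predefines a SparkSession named spark, and using illustrative column names:

import org.apache.spark.sql.functions.broadcast
import spark.implicits._

val flightsDF = flights.toDF("origin", "dest", "airline", "flight", "dep")
val airportsDF = airports.toDF("code", "name", "city", "state")

// broadcast() asks the optimizer to replicate the small dimension
// table to every executor instead of shuffling both sides
val joined = flightsDF.join(broadcast(airportsDF),
  flightsDF("origin") === airportsDF("code"))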
