Map-Side Join in Spark

Joining two or more data sets is one of the most common operations you perform on your data, but in distributed systems it can be a huge headache. Since the data are distributed among many nodes, they usually have to be shuffled before a join, which causes significant network I/O and slow performance.

Fortunately, if you need to join a large table (fact) with relatively small tables (dimensions), i.e. to perform a star-schema join, you can avoid sending all the data of the large table over the network. This type of join is called a map-side join in the Hadoop community. In other distributed systems, it is often called a replicated or broadcast join.

Let’s use the following sample data (one fact and two dimension tables):

// Fact table
val flights = sc.parallelize(List(
  ("SEA", "JFK", "DL", "418",  "7:00"),
  ("SFO", "LAX", "AA", "1250", "7:05"),
  ("SFO", "JFK", "VX", "12",   "7:05"),
  ("JFK", "LAX", "DL", "424",  "7:10"),
  ("LAX", "SEA", "DL", "5737", "7:10")))  
// Dimension table
val airports = sc.parallelize(List(
  ("JFK", "John F. Kennedy International Airport", "New York", "NY"),
  ("LAX", "Los Angeles International Airport", "Los Angeles", "CA"),
  ("SEA", "Seattle-Tacoma International Airport", "Seattle", "WA"),
  ("SFO", "San Francisco International Airport", "San Francisco", "CA")))
// Dimension table
val airlines = sc.parallelize(List(
  ("AA", "American Airlines"), 
  ("DL", "Delta Airlines"), 
  ("VX", "Virgin America")))   

We need to join the fact and dimension tables to get the following result:

Seattle           New York       Delta Airlines       418   7:00
San Francisco     Los Angeles    American Airlines    1250  7:05
San Francisco     New York       Virgin America       12    7:05
New York          Los Angeles    Delta Airlines       424   7:10
Los Angeles       Seattle        Delta Airlines       5737  7:10

The fact table can be very large, while dimension tables are often quite small. Let’s download the dimension tables to the Spark driver, create maps and broadcast them to each worker node:

val airportsMap = sc.broadcast(airports.map{case(a, b, c, d) => (a, c)}.collectAsMap)
val airlinesMap = sc.broadcast(airlines.collectAsMap)

Now you can run the map-side join:

flights.map{case(a, b, c, d, e) => 
   (airportsMap.value.get(a).get, 
    airportsMap.value.get(b).get, 
    airlinesMap.value.get(c).get, d, e)}.collect

The result of the execution (formatted):

res: Array[(String, String, String, String, String)] = Array(
  (Seattle, New York, Delta Airlines, 418, 7:00), 
  (San Francisco, Los Angeles, American Airlines, 1250, 7:05), 
  (San Francisco, New York, Virgin America, 12, 7:05), 
  (New York, Los Angeles, Delta Airlines, 424, 7:10), 
  (Los Angeles, Seattle, Delta Airlines, 5737, 7:10))

How It Works

First we created an RDD for each table. airports and airlines are the dimension tables that we use in the map-side join, so we converted them to maps and broadcast them to each execution node. Note that we extracted only 2 columns from the airports table.

Then we applied the map function to each row of the flights table and retrieved the dimension values from airportsMap and airlinesMap. If the flights table is very large, the map function is executed concurrently for each partition, and each executor has its own copy of airportsMap and airlinesMap.

This approach avoids shuffling the fact table and gives quite good join performance.
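One caveat of the lookup above is that `.get(key).get` throws a NoSuchElementException when a flight refers to a code that is missing from a dimension map. The following is a minimal sketch, not part of the original code, of a per-row lookup that drops unmatched rows instead (inner-join semantics). The names `Flight`, `joinRow`, `airportCities` and `airlineNames` are hypothetical, and plain Scala maps stand in for `airportsMap.value` and `airlinesMap.value` so the logic can be tried without a cluster.

```scala
// One flight row: (origin, destination, carrier, flight number, departure time)
type Flight = (String, String, String, String, String)

// Look up all three dimension values; yield None when any key is missing,
// so an unmatched row is dropped instead of raising NoSuchElementException.
def joinRow(airports: scala.collection.Map[String, String],
            airlines: scala.collection.Map[String, String])
           (row: Flight): Option[Flight] = {
  val (origin, dest, carrier, number, time) = row
  for {
    originCity  <- airports.get(origin)
    destCity    <- airports.get(dest)
    airlineName <- airlines.get(carrier)
  } yield (originCity, destCity, airlineName, number, time)
}

// Local stand-ins (assumed data) for airportsMap.value and airlinesMap.value
val airportCities = Map("SEA" -> "Seattle", "JFK" -> "New York")
val airlineNames  = Map("DL" -> "Delta Airlines")

val matched   = joinRow(airportCities, airlineNames)(("SEA", "JFK", "DL", "418", "7:00"))
val unmatched = joinRow(airportCities, airlineNames)(("SEA", "XXX", "DL", "1", "7:00"))
```

In spark-shell this could presumably be used as `flights.flatMap(row => joinRow(airportsMap.value, airlinesMap.value)(row))`. Reading `.value` inside the lambda keeps the lookup on the broadcast copy rather than serializing the maps with the closure.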


FIRST_VALUE and LAST_VALUE Analytic Functions in Spark

I needed to migrate a MapReduce job to Spark, but this job had previously been migrated from SQL :) and contains an implementation of the FIRST_VALUE, LAST_VALUE, LEAD and LAG analytic window functions in its reducer.

In this article I would like to share some ideas on how to implement the FIRST_VALUE and LAST_VALUE analytic functions in Spark (not Spark SQL). It is quite easy to extend the code to implement the LEAD and LAG functions with any specified offset.

Let’s define a sample data set as follows (the first column is the grouping key, the second is the value):

val data=sc.parallelize(List(
   ("A", "A1"),
   ("A", "A2"),
   ("A", "A3"),
   ("B", "B1"),
   ("B", "B2"),
   ("C", "C1")))

The result we need:

Group key    Value     First Value in Group      Last Value in Group
---------    -----     --------------------      -------------------
        A       A1                       A1                       A3  
        A       A2                       A1                       A3   
        A       A3                       A1                       A3   
        B       B1                       B1                       B2   
        B       B2                       B1                       B2 
        C       C1                       C1                       C1 

I defined a function that returns each value together with the first and last values in its group:

def firstLastValue(items: Iterable[String]) = for { i <- items 
} yield (i, items.head, items.last)

Now I can group rows by a key and get first and last values:

data.groupByKey().map{case(a, b)=>(a, firstLastValue(b))}.collect 

The result of the execution (formatted):

res: Array[(String, Iterable[(String, String, String)])] = Array(
 (B, List((B1,B1,B2), (B2,B1,B2))), 
 (A, List((A1,A1,A3), (A2,A1,A3), (A3,A1,A3))), 
 (C, List((C1,C1,C1))))

Note that I used groupByKey, not reduceByKey, as we need to work with the entire window.

How It Works

First we use the Spark groupByKey function to collect and group all values for each key in the data set. As a result, for each key we get the key and the collection of all values for this key.

The next step is to iterate through all values and return a tuple containing the value itself as well as the first and last values in the collection. You can extend this approach to get the values from the previous and following rows with any offset you need.
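The extension to LEAD and LAG mentioned above can be sketched as follows. This is an illustration under assumptions rather than the code from the migrated job: the helper name `lagLead` is hypothetical, `None` stands in for SQL NULL, and the values are sorted as strings inside each group because groupByKey does not promise any particular order of values.

```scala
// For each value, return (value, LAG(value, offset), LEAD(value, offset)).
// None plays the role of SQL NULL when the offset falls outside the group.
def lagLead(items: Iterable[String],
            offset: Int = 1): Seq[(String, Option[String], Option[String])] = {
  // Fix the window order explicitly (assumption: plain string ordering)
  val sorted = items.toIndexedSeq.sorted
  sorted.indices.map { i =>
    // lift returns None for an index outside the sequence
    (sorted(i), sorted.lift(i - offset), sorted.lift(i + offset))
  }
}

// Values of group A, deliberately out of order
val groupA = lagLead(List("A2", "A1", "A3"))
```

With the sample data this would be applied per key, for example `data.groupByKey().mapValues(v => lagLead(v))`. The same sorted sequence also makes `head` and `last` in firstLastValue deterministic.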

Multi-Column Key and Value – Reduce a Tuple in Spark

In many tutorials a key-value pair is typically a pair of single scalar values, for example (‘Apple’, 7). But key-value is a general concept: both the key and the value often consist of multiple fields, and both can be non-unique.

Consider a typical SQL statement:

SELECT store, product, SUM(amount), MIN(amount), MAX(amount), SUM(units)
FROM sales
GROUP BY store, product

The columns store and product can be considered the key, and the columns amount and units the values.

Let’s implement this SQL statement in Spark. First we define a sample data set:

val sales=sc.parallelize(List(
   ("West",  "Apple",  2.0, 10),
   ("West",  "Apple",  3.0, 15),
   ("West",  "Orange", 5.0, 15),
   ("South", "Orange", 3.0, 9),
   ("South", "Orange", 6.0, 18),
   ("East",  "Milk",   5.0, 5)))

The Spark/Scala code equivalent to the SQL statement is as follows:

sales.map{ case (store, prod, amt, units) => ((store, prod), (amt, amt, amt, units)) }.
  reduceByKey((x, y) => 
   (x._1 + y._1, math.min(x._2, y._2), math.max(x._3, y._3), x._4 + y._4)).collect

The result of the execution (formatted):

res: Array[((String, String), (Double, Double, Double, Int))] = Array(
  ((West, Orange), (5.0, 5.0, 5.0, 15)), 
  ((West, Apple),  (5.0, 2.0, 3.0, 25)), 
  ((East, Milk),   (5.0, 5.0, 5.0, 5)), 
  ((South, Orange),(9.0, 3.0, 6.0, 27)))

How It Works

We have an input RDD sales containing 6 rows and 4 columns (String, String, Double, Int). The first step is to define which columns belong to the key and which to the value. You can use the map function on the RDD as follows:

sales.map{ case (store, prod, amt, units) => ((store, prod), (amt, amt, amt, units)) }

We defined a key-value tuple where the key is itself a tuple containing (store, prod) and the value is a tuple containing the final results we are going to calculate: (amt, amt, amt, units).

Note that we initialized SUM, MIN and MAX with amt, so if there is only one row in a group then SUM, MIN, MAX values will be the same and equal to amt.

The next step is to reduce values by key:

  reduceByKey((x, y) => 

In this function x is the value accumulated so far for the key (the result of reducing the previous values), and y is the next value to merge in. Remember that both x and y are tuples with the same layout as (amt, amt, amt, units).

The reduce function must be associative (and, in Spark, commutative), and it works similarly to adding 2 + 4 + 3 + 5 + 6 + … You take the first 2 values and add them, then take the 3rd value and add it to the result, and so on.
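Associativity matters because Spark may combine values within each partition first and then merge the partial results, so the grouping and order of the calls are not fixed. The sketch below checks this locally in plain Scala. The alias `Agg`, the function name `combine` and the three rows are hypothetical illustration values, not part of the sales data set.

```scala
// Value layout from the article: (SUM(amount), MIN(amount), MAX(amount), SUM(units))
type Agg = (Double, Double, Double, Int)

// Same logic as the function passed to reduceByKey
def combine(x: Agg, y: Agg): Agg =
  (x._1 + y._1, math.min(x._2, y._2), math.max(x._3, y._3), x._4 + y._4)

// Three hypothetical rows for one key, already mapped to (amt, amt, amt, units)
val a: Agg = (2.0, 2.0, 2.0, 10)
val b: Agg = (3.0, 3.0, 3.0, 15)
val c: Agg = (4.0, 4.0, 4.0, 20)

val leftToRight = combine(combine(a, b), c)     // (a + b) + c
val regrouped   = combine(a, combine(b, c))     // a + (b + c), like merging partial results
val reordered   = List(c, a, b).reduce(combine) // a different input order
```

All three values come out as (9.0, 2.0, 4.0, 45), which is why the result does not depend on how the rows are split across partitions.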

Now it is easier to understand the meaning of:

reduceByKey((x, y) => 
  (x._1 + y._1, math.min(x._2, y._2), math.max(x._3, y._3), x._4 + y._4)).collect

Sum of the previous and current amounts (_1 means the first item of a tuple):

x._1 + y._1

Selecting MIN amount:

math.min(x._2, y._2)

Selecting MAX amount:

math.max(x._3, y._3)

Sum of the previous and current units:

x._4 + y._4
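Positional accessors such as _1 and _4 get hard to read as the value tuple grows. One possible alternative, shown here only as a sketch, is a small case class with named fields. The names `Stats`, `merge`, `rows` and `totals` are hypothetical, and plain Scala collections emulate map plus reduceByKey so the result can be compared with the output shown earlier.

```scala
// Hypothetical named replacement for the (amt, amt, amt, units) tuple
case class Stats(sum: Double, min: Double, max: Double, units: Int) {
  // Same reduction as the tuple version, with named fields
  def merge(o: Stats): Stats =
    Stats(sum + o.sum, math.min(min, o.min), math.max(max, o.max), units + o.units)
}

// A subset of the sample sales rows
val rows = List(
  ("West",  "Apple",  2.0, 10),
  ("West",  "Apple",  3.0, 15),
  ("South", "Orange", 3.0, 9),
  ("South", "Orange", 6.0, 18))

// Local emulation of map + reduceByKey using plain collections
val totals: Map[(String, String), Stats] = rows
  .map { case (store, prod, amt, units) => ((store, prod), Stats(amt, amt, amt, units)) }
  .groupBy(_._1)
  .map { case (key, pairs) => key -> pairs.map(_._2).reduce(_ merge _) }
```

On an RDD the equivalent would presumably be `sales.map{ ... }.reduceByKey(_ merge _)`, with the map step building `Stats(amt, amt, amt, units)` instead of the tuple.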