Englika

How to get N rows per group in SQL

Let's assume that you are developing the home page of an online store. This page should display product categories with 10 products in each. How should you write the database query? Which indexes do you need to create to speed it up?

Let's create the products table and fill it with data:

CREATE TABLE products (id int, category_id int, name varchar);
INSERT INTO products SELECT i, FLOOR(RANDOM() * 5 + 1), 'name' || i FROM generate_series(1, 1000000) i;

The table will have 1 million products divided into 5 categories.

The first thing you might think of is using LIMIT with GROUP BY, but it won't work: GROUP BY first collapses each group into a single row, and only then is LIMIT applied, so LIMIT restricts the number of groups, not the number of rows per group. For all rows of a group to collapse into one row, every other column in the SELECT list must be wrapped in an aggregate function (COUNT, SUM, AVG, etc.); otherwise we get an error.

SELECT COUNT(id), category_id FROM products GROUP BY category_id LIMIT 10;
 count  | category_id 
--------+-------------
 199645 |           1
 199753 |           2
 199833 |           3
 200037 |           4
 200732 |           5

This is definitely not what we want.

There are 2 ways to solve this problem. Let's look at both, work out which indexes speed up each query, and determine which query is the fastest.

Way 1. Using a window function

A window function performs a calculation over the rows of each group, just as aggregate functions (COUNT, SUM, AVG, etc.) do, but the rows of a group are not collapsed into one row.

To make this clearer, let's create a small deals table and run both an aggregate function and a window function on it.

CREATE TABLE deals (id int, user_id int);
INSERT INTO deals SELECT i, i % 2 + 1 FROM generate_series(1, 7) i;

There are 7 deals in the table. The first user has 3 deals, and the second has 4.

First, let's use the aggregate function COUNT.

SELECT COUNT(*) FROM deals;
 count 
-------
     7
(1 row)

The aggregate function COUNT counted all the rows and returned the result as a single row. If we add GROUP BY user_id, the number of deals for each user is displayed.

SELECT user_id, COUNT(*) FROM deals GROUP BY user_id;
 user_id | count 
---------+-------
       2 |     4
       1 |     3
(2 rows)

Now let's look at COUNT as a window function. To use it this way, add OVER() after the function.

SELECT id, COUNT(*) OVER() FROM deals;
 id | count 
----+-------
  1 |     7
  2 |     7
  3 |     7
  4 |     7
  5 |     7
  6 |     7
  7 |     7
(7 rows)

As you can see, the total number of rows was calculated again, but unlike SELECT COUNT(*) FROM deals, it is output in every row rather than in a single row. If OVER is specified without parameters, a single window (group) includes all rows of the table.

To display the number of deals per user, the window must be a group of rows rather than the entire table. This is done with PARTITION BY.

SELECT *, COUNT(*) OVER(PARTITION BY user_id) FROM deals;
 id | user_id | count 
----+---------+-------
  4 |       1 |     3
  2 |       1 |     3
  6 |       1 |     3
  1 |       2 |     4
  5 |       2 |     4
  7 |       2 |     4
  3 |       2 |     4
(7 rows)

Any aggregate function can be used as a window function, and there are also specialized window functions.

Let's use the ROW_NUMBER window function, which returns the number of the current row within its group, starting from 1.
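A minimal sketch of both kinds, assuming PostgreSQL; the seven deals are rebuilt in a CTE so the query runs on its own, and the column aliases running_sum and rank_by_newest are made up for illustration. SUM is an ordinary aggregate used with OVER, while RANK exists only as a window function. With ORDER BY inside OVER, SUM by default covers the rows from the start of the group up to the current row, so it produces a running total.

```sql
-- The CTE rebuilds the same seven deals as the deals table above,
-- so the query runs on its own.
WITH deals(id, user_id) AS (
    SELECT i, i % 2 + 1 FROM generate_series(1, 7) i
)
SELECT id,
       user_id,
       -- an ordinary aggregate used as a window function:
       -- with ORDER BY, it sums from the first row of the group to the current row
       SUM(id) OVER (PARTITION BY user_id ORDER BY id) AS running_sum,
       -- a function that exists only as a window function
       RANK() OVER (PARTITION BY user_id ORDER BY id DESC) AS rank_by_newest
FROM deals
ORDER BY user_id, id;
```

For the first user (deals 2, 4 and 6), running_sum is 2, 6 and 12, and deal 6 gets rank_by_newest = 1.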

SELECT *, ROW_NUMBER() OVER(PARTITION BY user_id) FROM deals;
 id | user_id | row_number 
----+---------+------------
  4 |       1 |          1
  2 |       1 |          2
  6 |       1 |          3
  1 |       2 |          1
  5 |       2 |          2
  7 |       2 |          3
  3 |       2 |          4
(7 rows)

We now have a sequence number for each user's deals. With an outer query, we can keep only the rows whose sequence number is less than or equal to 2. As a result, we get 2 deals for each user.

SELECT * FROM (
    SELECT *, ROW_NUMBER() OVER(PARTITION BY user_id) FROM deals
) t WHERE row_number <= 2;
 id | user_id | row_number 
----+---------+------------
  4 |       1 |          1
  2 |       1 |          2
  1 |       2 |          1
  5 |       2 |          2
(4 rows)

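You can also set the order in which the window function processes rows by using ORDER BY inside OVER. Without it, the numbering order within a group is not guaranteed, as the arbitrary order of ids in the previous output shows.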

SELECT * FROM (
    SELECT *, ROW_NUMBER() OVER(PARTITION BY user_id ORDER BY id) FROM deals
) t WHERE row_number <= 2;
 id | user_id | row_number 
----+---------+------------
  2 |       1 |          1
  4 |       1 |          2
  1 |       2 |          1
  3 |       2 |          2
(4 rows)

The rows within each group are now numbered in order of the id column.

Let's go back to the products table with 1 million rows and see how long such a query takes.

EXPLAIN ANALYZE SELECT * FROM (
    SELECT *, ROW_NUMBER() OVER (
        PARTITION BY category_id ORDER BY id DESC
    )
    FROM products
) t WHERE row_number <= 10;
 Subquery Scan on t  (cost=136536.84..169036.84 rows=333333 width=26) (actual time=3128.932..3912.707 rows=50 loops=1)
   Filter: (t.row_number <= 10)
   Rows Removed by Filter: 999950
   ->  WindowAgg  (cost=136536.84..156536.84 rows=1000000 width=26) (actual time=3128.929..3855.424 rows=1000000 loops=1)
         ->  Sort  (cost=136536.84..139036.84 rows=1000000 width=18) (actual time=3123.940..3440.984 rows=1000000 loops=1)
               Sort Key: products.category_id, products.id DESC
               Sort Method: external merge  Disk: 28304kB
               ->  Seq Scan on products  (cost=0.00..16369.00 rows=1000000 width=18) (actual time=0.060..53.225 rows=1000000 loops=1)
 Planning Time: 1.518 ms
 Execution Time: 3915.502 ms

The query took almost 4 seconds.

As you can see, the subquery was executed first: the entire products table was scanned, sorted by category_id and then by id in descending order (the sort did not fit in memory and spilled to disk), and the window function numbered the rows within each group. Then the outer query kept only the rows with a sequence number less than or equal to 10.

To speed up this query, let's add a B-Tree index on 2 columns, category_id and id (descending), so that rows from the products table are read already in the required order and no separate sort is needed.

CREATE INDEX products_category_id_id_idx ON products (category_id, id DESC);

Let's execute the query again.

EXPLAIN ANALYZE SELECT * FROM (
    SELECT *, ROW_NUMBER() OVER (
        PARTITION BY category_id ORDER BY id DESC
    )
    FROM products
) t WHERE row_number <= 10;
 Subquery Scan on t  (cost=0.42..81053.98 rows=333333 width=26) (actual time=1.585..686.557 rows=50 loops=1)
   Filter: (t.row_number <= 10)
   Rows Removed by Filter: 999950
   ->  WindowAgg  (cost=0.42..68553.98 rows=1000000 width=26) (actual time=1.583..619.792 rows=1000000 loops=1)
         ->  Index Scan using products_category_id_id_idx on products  (cost=0.42..51053.98 rows=1000000 width=18) (actual time=1.568..180.026 rows=1000000 loops=1)
 Planning Time: 3.580 ms
 Execution Time: 686.628 ms

Now the query took 687 ms, about 5.7 times faster.

If, as in my case, you need to group by several columns rather than one, list them all after PARTITION BY. For example, PARTITION BY record_id, field_id.

If the categories are paginated (because there are a lot of them), add the WHERE category_id IN (1, 2, 4) condition to the subquery, where 1, 2, 4 are the ids of the categories shown on the current page (assume that the category with id=3 has been deleted). The query then reads only the rows of those categories, so it runs even faster.

The main disadvantage of this approach is that the first stage (the subquery) has to read all the rows in the selected groups (product categories). These rows are numbered by the window function, and only after that are the first N rows of each group kept.
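A sketch of partitioning by two columns. Only the column names record_id and field_id come from the text above; the field_values table and its rows are hypothetical, built with VALUES so the query runs on its own. It keeps the 2 rows with the highest id for every (record_id, field_id) pair.

```sql
-- Hypothetical data: each row belongs to a (record_id, field_id) pair.
WITH field_values(record_id, field_id, id) AS (
    VALUES (1, 1, 10), (1, 1, 11), (1, 1, 12),
           (1, 2, 20), (1, 2, 21),
           (2, 1, 30)
)
SELECT * FROM (
    SELECT *, ROW_NUMBER() OVER (
        -- a separate numbering for every (record_id, field_id) combination
        PARTITION BY record_id, field_id ORDER BY id DESC
    )
    FROM field_values
) t WHERE row_number <= 2;
```

The pair (1, 1) keeps ids 12 and 11, (1, 2) keeps 21 and 20, and (2, 1) keeps its only row, 30: 5 rows in total.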
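A sketch of the paginated variant. To keep it self-contained, the products table is replaced by a small deterministic CTE with the same name (100 products, 20 per category); this stand-in is an assumption for illustration, and against the real table you would drop the WITH clause. The filter sits inside the subquery, so only the rows of the requested categories are numbered.

```sql
-- Stand-in for the products table (20 products per category);
-- against the real table, drop the WITH clause.
WITH products(id, category_id, name) AS (
    SELECT i, i % 5 + 1, 'name' || i FROM generate_series(1, 100) i
)
SELECT * FROM (
    SELECT *, ROW_NUMBER() OVER (
        PARTITION BY category_id ORDER BY id DESC
    )
    FROM products
    -- only the categories shown on the current page
    WHERE category_id IN (1, 2, 4)
) t WHERE row_number <= 10;
```

On this sample data the query returns 30 rows: the 10 products with the highest id in each of the 3 requested categories.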

Is it possible to select N rows for each group directly, without wasting time reading all the rows and then filtering them?

Way 2. Using JOIN LATERAL

The ideal solution to our problem would be to perform only 2 steps:

  1. Take all the groups. In our example, these are product categories.
  2. For each group, take only N results and combine them into one table.

Let's first select all the product categories that are used in the products table. To do this, we can use DISTINCT, which returns each category_id value only once.

SELECT DISTINCT category_id FROM products;
 category_id 
-------------
           1
           3
           5
           4
           2
(5 rows)

Now, using JOIN LATERAL, we take N rows for each group and combine them into one result. LATERAL allows a subquery in the FROM clause to refer to columns of tables listed before it. Thus, we can first get all the categories and then take N products for each category.

SELECT r.*
FROM (
    SELECT DISTINCT category_id FROM products ORDER BY category_id
) categories
JOIN LATERAL (
    SELECT * FROM products
    WHERE category_id = categories.category_id
    ORDER BY id DESC
    LIMIT 2
) r ON true;
   id    | category_id |    name     
---------+-------------+-------------
  999997 |           1 | name999997
  999995 |           1 | name999995
  999992 |           2 | name999992
  999984 |           2 | name999984
  999999 |           3 | name999999
  999998 |           3 | name999998
  999996 |           4 | name999996
  999994 |           4 | name999994
 1000000 |           5 | name1000000
  999993 |           5 | name999993
(10 rows)

Running the same query with EXPLAIN ANALYZE gives the following plan:

 Nested Loop  (cost=18869.53..18873.03 rows=10 width=18) (actual time=225.723..225.749 rows=10 loops=1)
   ->  Sort  (cost=18869.11..18869.12 rows=5 width=4) (actual time=225.675..225.676 rows=5 loops=1)
         Sort Key: products.category_id
         Sort Method: quicksort  Memory: 25kB
         ->  HashAggregate  (cost=18869.00..18869.05 rows=5 width=4) (actual time=225.648..225.650 rows=5 loops=1)
               Group Key: products.category_id
               Batches: 1  Memory Usage: 24kB
               ->  Seq Scan on products  (cost=0.00..16369.00 rows=1000000 width=4) (actual time=0.012..55.384 rows=1000000 loops=1)
   ->  Limit  (cost=0.42..0.73 rows=2 width=18) (actual time=0.013..0.013 rows=2 loops=5)
         ->  Index Scan using products_category_id_id_idx on products products_1  (cost=0.42..30662.64 rows=200000 width=18) (actual time=0.012..0.013 rows=2 loops=5)
               Index Cond: (category_id = products.category_id)
 Planning Time: 0.172 ms
 Execution Time: 225.824 ms

The query took 226 ms, about 3 times faster than the query with the window function. It is executed as follows:

  1. All categories that are used in the products table are selected and sorted in ascending order.
  2. For each category selected by the first subquery, the 2 products with the highest id are selected. ON true means that every row returned by the lateral subquery is joined to its category row without any additional condition.
  3. Only the columns of the lateral subquery (r.*) are returned, to avoid duplicating the category_id column.

To speed up the query, we need the same index that we created earlier:
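One variation worth knowing, sketched with made-up rows: with a plain JOIN LATERAL, a category whose subquery returns no rows disappears from the result. LEFT JOIN LATERAL ... ON true keeps such a category, with NULLs in the product columns. Whether an empty category should appear on the page is an assumption about your use case; the original query does not need this.

```sql
-- Made-up rows: category 3 has no products.
WITH categories(id) AS (VALUES (1), (2), (3)),
     products(id, category_id, name) AS (
         VALUES (1, 1, 'a'), (2, 1, 'b'), (3, 1, 'c'), (4, 2, 'd')
     )
SELECT categories.id AS category_id, r.id, r.name
FROM categories
-- LEFT JOIN keeps a category even when the lateral subquery returns no rows
LEFT JOIN LATERAL (
    SELECT * FROM products
    WHERE products.category_id = categories.id
    ORDER BY products.id DESC
    LIMIT 2
) r ON true
ORDER BY categories.id, r.id DESC;
```

This returns 4 rows: products 3 and 2 for category 1, product 4 for category 2, and a row with NULLs for category 3. With a plain JOIN LATERAL, the row for category 3 would be missing.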

CREATE INDEX products_category_id_id_idx ON products (category_id, id DESC);

The first subquery was executed with a sequential scan. If you select only some of the categories (e.g. 2 out of 5), the planner can use the index instead.

You may notice that almost all the execution time is spent on the first subquery, which finds all the categories used in the products table.

EXPLAIN ANALYZE SELECT DISTINCT category_id FROM products ORDER BY category_id;
 Sort  (cost=18869.11..18869.12 rows=5 width=4) (actual time=225.985..225.986 rows=5 loops=1)
   Sort Key: category_id
   Sort Method: quicksort  Memory: 25kB
   ->  HashAggregate  (cost=18869.00..18869.05 rows=5 width=4) (actual time=225.961..225.963 rows=5 loops=1)
         Group Key: category_id
         Batches: 1  Memory Usage: 24kB
         ->  Seq Scan on products  (cost=0.00..16369.00 rows=1000000 width=4) (actual time=0.010..56.372 rows=1000000 loops=1)
 Planning Time: 0.076 ms
 Execution Time: 226.043 ms

Most likely, you already have a separate table with product categories.

CREATE TABLE categories (id int, name varchar);
INSERT INTO categories SELECT i, 'name' || i FROM generate_series(1, 5) i;

So there is no need to scan the entire products table to find these categories. Let's take them directly from the categories table.

EXPLAIN ANALYZE SELECT r.*
FROM (
    SELECT id FROM categories ORDER BY id
) categories
JOIN LATERAL (
    SELECT * FROM products
    WHERE category_id = categories.id
    ORDER BY id DESC
    LIMIT 2
) r ON true;
 Nested Loop  (cost=1.53..5.03 rows=10 width=18) (actual time=0.083..0.127 rows=10 loops=1)
   ->  Sort  (cost=1.11..1.12 rows=5 width=4) (actual time=0.038..0.040 rows=5 loops=1)
         Sort Key: categories.id
         Sort Method: quicksort  Memory: 25kB
         ->  Seq Scan on categories  (cost=0.00..1.05 rows=5 width=4) (actual time=0.012..0.014 rows=5 loops=1)
   ->  Limit  (cost=0.42..0.73 rows=2 width=18) (actual time=0.015..0.016 rows=2 loops=5)
         ->  Index Scan using products_category_id_id_idx2 on products  (cost=0.42..30662.64 rows=200000 width=18) (actual time=0.014..0.015 rows=2 loops=5)
               Index Cond: (category_id = categories.id)
 Planning Time: 0.301 ms
 Execution Time: 0.226 ms

The query took 0.2 ms: about 1000 times faster than the previous query with DISTINCT and about 3000 times faster than the indexed query with the window function.

Conclusion

To get N rows in each group, it is better to use a query with JOIN LATERAL, which is much faster than a query with a window function. With a window function, all rows of each group are read first and the unnecessary rows are discarded afterwards, whereas JOIN LATERAL lets you refer to the already selected groups and read only N rows for each group directly.

You can select groups in one of 3 ways:

  1. Select them from a separate table. For example, SELECT id FROM categories.
  2. Generate them if they are already known. For example, SELECT id FROM generate_series(1, 5) AS t(id).
  3. Select them from the same table. For example, SELECT DISTINCT category_id FROM products.
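Option 2 can be sketched as follows, assuming the category ids are exactly 1 to 5. The products CTE is a small deterministic stand-in for the real table (an assumption so the query runs on its own); the groups come from generate_series instead of a table, and the lateral subquery is the same as before.

```sql
-- Stand-in for the products table so the sketch runs on its own.
WITH products(id, category_id, name) AS (
    SELECT i, i % 5 + 1, 'name' || i FROM generate_series(1, 100) i
)
SELECT r.*
-- the groups are generated instead of being read from a table
FROM generate_series(1, 5) AS categories(id)
JOIN LATERAL (
    SELECT * FROM products
    WHERE products.category_id = categories.id
    ORDER BY products.id DESC
    LIMIT 2
) r ON true;
```

On this sample data the result has 10 rows, 2 per category; for example, category 1 gets products 100 and 95.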

To speed up the search, create an index on the grouping column followed by the sort column, for example (category_id, id DESC).

Cheat sheet

-- Create tables with categories and products
CREATE TABLE categories (id int, name varchar);
CREATE TABLE products (id int, category_id int, name varchar);

-- Fill in the tables
INSERT INTO categories SELECT i, 'name' || i FROM generate_series(1, 5) i;
INSERT INTO products SELECT i, FLOOR(RANDOM() * 5 + 1), 'name' || i FROM generate_series(1, 1000000) i;

-- Create a B-Tree index to speed up the search
CREATE INDEX products_category_id_id_idx ON products (category_id, id DESC);

-- Select the last 2 products (highest id) in each category
SELECT r.*
FROM (
    SELECT id FROM categories ORDER BY id
) categories
JOIN LATERAL (
    SELECT * FROM products
    WHERE category_id = categories.id
    ORDER BY id DESC
    LIMIT 2
) r ON true;
