それKotlinなら5行でできるよ

このブログは NewsPicks Advent Calendar 2022 2日目の記事です。

qiita.com

こんにちは。むとうです。

プログラミングって難しいですよね。昔スゴーク頑張って何十行も書いたコードをベテランの人に一行に直されて、衝撃が走ったことを覚えています。手練との歴然とした力の差を見せつけられる、みなさんもそういう経験があるのではないでしょうか?

私はかれこれ20年近くプログラミングを教えてきており、人のコードの書きっぷりを見て「これを身につければもっといい感じにできるのに」と感じることも増えてきました。そういう技は決して難しいものばかりではありません。誰でもすぐに身につけられて一生もののスキルとして使える技の一つに「コレクション処理」があります。実用的なプログラムの多くの部分はコレクション操作で成り立っています。コレクション操作を適切につかえば難しい仕掛けや独自実装をせずともやりたいことができるケースは多いものです。

今日は具体的な例をみながら、Kotlinのコレクション関数をご紹介します。

まずはJavaで書かれたこちらのコードをご覧ください。

    public List<UserModel> getFollowingsUsers(Integer userId, Integer limit) {
        // フォロー中のユーザ一覧を取得
        Collection<Integer> followingIds = getCachedFollowingIds(userId);
        if (followingIds.isEmpty()) {
            return Collections.emptyList();
        }

        List<Integer> reverseSortedIds = new ArrayList<>(followingIds);
        Collections.reverse(reverseSortedIds);
        // getFollowings()と同じソート順とする
        // https://docs.google.com/document/...
        List<Integer> sortedUserIds = reverseSortedIds.stream().sorted(followingComparator()).collect(Collectors.toList());

        // フォロー中のユーザ一覧をUserType.NORMALのもので絞り込む
        int cursor = 0;
        List<Integer> normalUserIds = new ArrayList<Integer>();
        do {
            // カーソルからMAX_SIZE件を切り出す
            List<Integer> fetchUserIds = sortedUserIds.stream().skip(cursor).limit(MAX_SIZE).collect(Collectors.toList());
            List<Integer> filteredUserIds = userService.filterNormalUserIds(fetchUserIds);
            normalUserIds.addAll(filteredUserIds);
            cursor += MAX_SIZE;
        } while ((normalUserIds.size() < limit) && (cursor < sortedUserIds.size()));
        // filterNormalUserIds()でinを使用してテーブル検索をしているため、ソート順が保証されていないので、元のリストのソート順でソートして上位limit件を取り出す
        List<Integer> followingNormalUserIds = normalUserIds.stream()
                                                            .map(user -> sortedUserIds.indexOf(user))
                                                            .sorted()
                                                            .limit(limit)
                                                            .map(index -> sortedUserIds.get(index))
                                                            .collect(Collectors.toList());

        // ユーザモデルを生成して返却する
        return followingNormalUserIds.stream().map(userService::getUser).filter(Objects::nonNull).collect(Collectors.toList());
    }

大変な力作です。33行と決して長くはないコードですが、ひと目見て何をやっているのかがわかりにくいためコメントを沢山書いてもらっています。しかしそれでも読み解くのに苦労するでしょう。直し甲斐のあるコードです。ワクワクしてきますね!

まずは頭を使わずにKotlinに変換してしまいましょう。IDEAを使っている人であれば、Javaのコードをコピーしてktファイルにペーストするだけで変換が可能です。

   fun getFollowingsUsers(userId: Int?, limit: Int): List<UserModel?>? {
        // フォロー中のユーザ一覧を取得
        val followingIds = getCachedFollowingIds(userId)
        if (followingIds.isEmpty()) {
            return Collections.emptyList()
        }
        val reverseSortedIds: List<Int> = ArrayList(followingIds)
        Collections.reverse(reverseSortedIds)
        // getFollowings()と同じソート順とする
        // https://docs.google.com/document/...
        val sortedUserIds: List<Int> = reverseSortedIds.stream().sorted(followingComparator()).collect(Collectors.toList())

        // フォロー中のユーザ一覧をUserType.NORMALのもので絞り込む
        var cursor = 0
        val normalUserIds: MutableList<Int> = ArrayList()
        do {
            // カーソルからMAX_SIZE件を切り出す
            val fetchUserIds: List<Int> = sortedUserIds.stream().skip(cursor.toLong()).limit(MAX_SIZE.toLong()).collect(Collectors.toList())
            val filteredUserIds: List<Int> = userService.filterNormalUserIds(fetchUserIds)
            normalUserIds.addAll(filteredUserIds)
            cursor += MAX_SIZE
        } while (normalUserIds.size < limit && cursor < sortedUserIds.size)
        // filterNormalUserIds()でinを使用してテーブル検索をしているため、ソート順が保証されていないので、元のリストのソート順でソートして上位limit件を取り出す
        val followingNormalUserIds: List<Int> = normalUserIds.stream()
            .map { user: Int -> sortedUserIds.indexOf(user) }
            .sorted()
            .limit(limit.toLong())
            .map { index: Int? -> sortedUserIds[index!!] }
            .collect(Collectors.toList())

        // ユーザモデルを生成して返却する
        return followingNormalUserIds.stream().map { userId: Int? -> userService.getUser(userId) }.filter(Objects::nonNull).collect(Collectors.toList())
    }

まだこれではKotlinらしい感じがしません。いろいろと冗長なところを整理してみます。

次のようにします。

  • 引数・返り値に?がついてしまったけど実際はnullがこないところの ? を取る。
  • stream()....collect(Collectors.toList())→単純に削除
  • 追記用の ArrayList の作成 → ArrayList クラスにこだわる場面ではないので、Kotlin標準の関数を使う。Kotlinでは変更用のコレクションはmutableを明示する必要があるので listOf<>() ではなく mutableListOf<>() を使う。
  • streamの関数をKotlin標準コレクション関数に置き換える。
stream Kotlin
map { xyz: Xyz -> f(xyz) } 冗長な型の指定をなくし、itを使う。 map { f(it) }
.limit() .take()
.skip() .drop()
Collections.reverse(x) x.reversed()
.sorted(comparator) .sortedWith(comparator)

コンパレータを使わない場合は .sorted() を使う。
ソートキーを関数で指定する場合は .sortedBy { ... } を使う。
.filter(Objects::nonNull) 直前の.mapと組み合わせて.mapNotNull()を使う。

このようになります。

   fun getFollowingsUsers(userId: Int, limit: Int): List<UserModel> {
        // フォロー中のユーザ一覧を取得
        val followingIds = getCachedFollowingIds(userId)
        if (followingIds.isEmpty()) {
            return emptyList()
        }
        val reverseSortedIds = followingIds.reversed()

        // getFollowings()と同じソート順とする
        // https://docs.google.com/document/...
        val sortedUserIds = reverseSortedIds.sortedWith(followingComparator())

        // フォロー中のユーザ一覧をUserType.NORMALのもので絞り込む
        var cursor = 0
        val normalUserIds = mutableListOf<Int>()
        do {
            // カーソルからMAX_SIZE件を切り出す
            val fetchUserIds = sortedUserIds.drop(cursor).take(MAX_SIZE)
            val filteredUserIds = userService.filterNormalUserIds(fetchUserIds)
            normalUserIds.addAll(filteredUserIds)
            cursor += MAX_SIZE
        } while (normalUserIds.size < limit && cursor < sortedUserIds.size)

        // filterNormalUserIds()でinを使用してテーブル検索をしているため、ソート順が保証されていないので、元のリストのソート順でソートして上位limit件を取り出す
        val followingNormalUserIds = normalUserIds
            .map { sortedUserIds.indexOf(it) }
            .sorted()
            .take(limit)
            .map { sortedUserIds[it] }

        // ユーザモデルを生成して返却する
        return followingNormalUserIds.mapNotNull { userService.getUser(it) }
    }

だいぶスッキリしてきました!やっと読む気になるコードになりました。

少しコードを読むための前提知識をご説明します。NewsPicksでは「ユーザー」は単純なユーザーだけではなく、記事を提供してくれるサプライヤーも含まれています。いまみている関数は引数で与えられたuserIdを持つユーザーがフォローしている対象のうち単純なユーザーだけに絞り込んだ結果を取得しようとしているものです。

getCachedFollowingIds() はフォロー対象のユーザーIDを返しますが、IDはその実単なる数値です。なのでDBに問い合わせて実際にユーザーなのかどうかを確認する必要があるということのようです。userService.filterNormalUserIds(fetchUserIds)はそのためのメソッドです。

仕様を整理します。

  • getCachedFollowingIds(userId: Int): List<Int>
    • フォローしているユーザーIDをキャッシュから返す。
  • userService.filterNormalUserIds(fetchUserIds: List<Int>): List<Int>
    • 引数のユーザーIDのうち実際にユーザーであるIDだけをDBに問い合わせて返す。引数を sqlのin句に入れて処理をするため大量のデータを入力してはいけない。返り値の順序は不定。

さて、サービス上フォロー数の上限はありませんので、数万ユーザーをフォローしている人がいた場合userService.filterNormalUserIds() は使えないようです。またfilterNormalUserIds() の返り値の順序が不定なところも問題の複雑さに拍車をかけています。

真ん中の部分の // フォロー中のユーザ一覧をUserType.NORMALのもので絞り込む が重要そうですね。DBへの問い合わせの入力が大きくなりすぎないように制御しています。ここから見ていきましょう。

        // フォロー中のユーザ一覧をUserType.NORMALのもので絞り込む
        var cursor = 0 // 何件処理したか
        val normalUserIds = mutableListOf<Int>() // 結果を入れる場所
        do {
            // カーソルからMAX_SIZE件を切り出す
            val fetchUserIds = sortedUserIds.drop(cursor).take(MAX_SIZE) // 最大処理数ずつ取り出す
            val filteredUserIds = userService.filterNormalUserIds(fetchUserIds) // DBに問い合わせ
            normalUserIds.addAll(filteredUserIds) // 結果を貯める
            cursor += MAX_SIZE // 処理数を足す
        } while (normalUserIds.size < limit && cursor < sortedUserIds.size) // 必要数(limit)を満たしていない かつ 入力をすべて処理していない間は繰り返す

処理内容にコメントをしました。意図していることはこのようなことだと考えられます。

  • DBへの問い合わせを成功させるため、MAX_SIZEずつ処理する。
  • DBへの不要な問い合わせを減らすため、必要数を満たしたところで処理を打ち切る。

MAX_SIZEずつ処理するためには、一つのリストをMAX_SIZEずつに区切ることができれば良さそうです。Kotlinにはこのための便利な関数があります。.chunked()です。言葉だけではわかりにくいので図示するとこのようになります。

chunked()がリストの要素をチャンクに区切る様子
chunked()

より複雑なユースケースについては、.windowed()で対応出来ます。.chunked()に加えて、前後のチャンクで要素を重ねたい場合にも対応出来ます。

windowed()がリストの要素を分割する様子
windowed()の挙動

今回は.chunked()を使って書き直してみましょう。

        // フォロー中のユーザ一覧をUserType.NORMALのもので絞り込む
        val normalUserIds = sortedUserIds
            .chunked(MAX_SIZE) // MAX_SIZEずつ処理する
            .flatMap { userService.filterNormalUserIds(it) }
            .take(limit)

新たに出てきた.flatMap().map().flatten()を組み合わせた挙動になります。

flatMap()がリストのリストを一つのリストにする様子
flatMap()

このままではまだ足りません。そうです、DBへの不要な問い合わせを減らすことが出来ていません。でも安心してください!Kotlinにはこのための便利な機能があります。Sequence<T>です。Sequence<T>はリストとは違い、各要素がメモリ上には存在せず実際に必要になるまで処理が遅延された状態になっています。他の言語では遅延シーケンスや遅延リストなどということもあるデータ構造です。.asSequence()を使うとSequence<T>を作ることが出来ます。Sequence<T>に対してはいままで行ってきたコレクションに対する処理は全て同じように行なえます。やってみましょう。

        // フォロー中のユーザ一覧をUserType.NORMALのもので絞り込む
        val normalUserIds = sortedUserIds
            .asSequence() // 必要なだけの処理を行えるようにする ← NEW!
            .chunked(MAX_SIZE) // MAX_SIZEずつ処理する
            .flatMap { userService.filterNormalUserIds(it) }
            .take(limit)

完璧です!スッキリしました!

次にDBに問い合わせた結果の順序が不定だった問題を解決しましょう。以下の部分です。

        // filterNormalUserIds()でinを使用してテーブル検索をしているため、ソート順が保証されていないので、元のリストのソート順でソートして上位limit件を取り出す
        val followingNormalUserIds = normalUserIds
            .map { sortedUserIds.indexOf(it) }
            .sorted()
            .take(limit)
            .map { sortedUserIds[it] }

DBへの問い合わせの結果得られた normalUserIdssortedUserIdsの順序に直したいわけなのですが、ちょっともう読むのが面倒になってきたので答えから言ってしまいます。 sortedUserIds から先程作った normalUserIds と重複する部分だけを残すことでほしいものが得られます。.intersect()を使います。図にするとこのようになります。

intersect() が重複する要素だけを残す様子
intersect()

ありがたいことに.intersect()の結果は元のリストの順序を保ってくれています!早速書き直してみましょう。

        // filterNormalUserIds()でinを使用してテーブル検索をしているため、ソート順が保証されていないので、元のリストのソート順でソートして上位limit件を取り出す
        val followingNormalUserIds =  sortedUserIds.intersect(normalUserIds.toSet())

.intersect() の引数には Sequence<T> をそのまま渡すことは出来ないので、遅延された中身を実体化しています。今回は「含まれているかどうか」を確認するためだけのものなので.toSet()を使いセットを作っています。(リストが欲しいケースでは.toList()を使うことができます) normalUserIdsにはすでに高々limit個のデータが入っているはずなのでここで改めて.take()する必要はもうないでしょう。

IOの最適化の観点からは、N+1問題を避けるために最後のこの処理も見直したいです。

        // ユーザモデルを生成して返却する
        return followingNormalUserIds.mapNotNull { userService.getUser(it) }

実はuserService.getUser()には複数処理版も存在していて、そちらを使うとIOを一つにまとめられて高速です。大変ありがたいことにこの処理はちゃんと順序を保ってくれます!

        // ユーザモデルを生成して返却する
        return userService.getUsers(followingNormalUserIds)

まとめるとこのようになりました。

    fun getFollowingsUsers(userId: Int, limit: Int): List<UserModel> {
        // フォロー中のユーザ一覧を取得
        val followingIds = getCachedFollowingIds(userId)
        val reverseSortedIds = followingIds.reversed()

        // getFollowings()と同じソート順とする
        // https://docs.google.com/document/...
        val sortedUserIds = reverseSortedIds.sortedWith(followingComparator())

        // フォロー中のユーザ一覧をUserType.NORMALのもので絞り込む
        val normalUserIds = sortedUserIds
            .asSequence()
            .chunked(MAX_SIZE)
            .flatMap { userService.filterNormalUserIds(it) as List<Int> }
            .take(limit)

        // filterNormalUserIds()でinを使用してテーブル検索をしているため、ソート順が保証されていないので、元のリストのソート順でソートして上位limit件を取り出す
        val followingNormalUserIds = sortedUserIds.intersect(normalUserIds.toSet())

        // ユーザモデルを生成して返却する
        return userService.getUsers(followingNormalUserIds)
    }

だいぶ短くなりました!ちょっとした修正ですがフォローが0の場合後続の処理で何も行われないだけなので0件の場合の例外処理は削除しました。 33行が22行になり、行も短くなって最初に比べて遥かにわかりやすくなったのではないかと思います。

ところでここで実は大事な話があります。通常のユーザーかどうかの判定はDBに問い合わせる必要はなく、IDを形式的に判断するだけで事足りるのです。というわけで、いままでの苦労は何だったのか?ソフトウェア開発にとって最も大事なのはドメイン知識なのです!

   fun getFollowingsUsers(userId: Int, limit: Int): List<UserModel> {
        val followingUserIds = getCachedFollowingIds(userId)
            .reversed()
            .filterNot { UserId.isSystemUser(it) }
            .sortedWith(followingComparator())
            .take(limit)
        return userService.getUsers(followingUserIds)
    }

ここまでくるとメソッドのドキュメンテーションコメントだけで十分で、実装にはコメントはもはや不要です。新しく.filterNot() が出てきましたが、ここまでお付き合いいただいたあなたにはもはや何も言うことはないでしょう。

Page top