型レベル多項式
Kazuki Moriyama (森山 和樹)
[mathjax]
dottyの柔軟な型を利用して型レベルの多項式とその一般の足し算を定義する。
多項式の定式化
まずはenumを用いて多項式の要素を型に落とし込む。
必要な要素は項とその足し算だ。
enum Polynomial:
case X[Coef <: Int, N <: Int]()
case Plus[XX <: X[?, ?], YY <: Polynomial]()
Xが項、Plusがその足し算を表している。
Xの型変数Coefは係数、Nは右の肩に乗数を表している。
つまりX[2, 3]は[2x^3]のことである。
infix記法を使って2 X 3のように書けばより普通の書き方の様に見える。
定数項は2 X 0の様にして表せる。
3 X 0 // 3
2 X 1 // 2x
(2 X 1) Plus (3 X 0) // 2x + 3
(1 X 2) Plus ((2 X 1) Plus (3 X 0)) // x^2 + 2x + 3
ヘルパー型
後で使いやすいようにXに対してヘルパー型を定義する。
type N[XX <: X[?, ?]] = XX match
case ? X n => n
type C[XX <: X[?, ?]] = XX match
case c X ? => c
それぞれXの型パラメータを抜き出す便利型である。
こんな便利型も昔はむちゃくちゃ頑張らないと書けなかったしいい時代になった。
summon[N[1 X 2] =:= 2]
summon[C[1 X 2] =:= 1]
一般の多項式同士の足し算
Plusでは一般の足し算が表現できないことは上で述べた。
ここで一般の足し算を定義したい。
更に項は乗数のオーダーでソートされてほしい。
そのような型演算は以下のようにして定義できる。
import scala.compiletime.ops.int
import scala.compiletime.ops.any.*
import Polynomial.*
type +[A <: Polynomial, B <: Polynomial] <: Polynomial = A match
case c X n =>
B match
case cc X nn =>
n == nn match
case true => int.+(c, cc) X n
case false =>
int.>(n, nn) match
case true => A Plus B
case false => B Plus A
case xx Plus yy =>
int.>(n, N[xx]) match
case true => A Plus B
case false =>
n == N[xx] match
case true => (int.+(c, C[xx]) X n) Plus yy
case false => xx Plus (A + yy)
case xx Plus yy =>
B match
case cc X nn =>
N[xx] == nn match
case true => (int.+(C[xx], cc) X nn) Plus yy
case false =>
int.>(N[xx], nn) match
case true => xx Plus (B + yy)
case false => B Plus A
case xxx Plus yyy =>
N[xx] == N[xxx] match
case true => (int.+(C[xx], C[xxx]) X N[xx]) Plus (yy + yyy)
case false =>
int.>(N[xx], N[xxx]) match
case true => xx Plus (yy + B)
case false => xxx Plus (yyy + A)
Match Typeが暴れている。
これを使えば以下のように一般の形での多項式の足し算ができる。
summon[(3 X 1) + (4 X 1) =:= (7 X 1)]
summon[(3 X 1) + (2 X 0) =:= ((3 X 1) Plus (2 X 0))]
summon[(3 X 1) + (4 X 2) =:= (4 X 2) + (3 X 1)]
summon[(3 X 2) + (4 X 1) + (2 X 0) =:= ((3 X 2) Plus ((4 X 1) Plus (2 X 0)))]
summon[(3 X 2) + ((4 X 1) + (2 X 0)) =:= (3 X 2) + (4 X
1) + (2 X 0)]
summon[(3 X 1) + ((4 X 1) + (2 X 0)) =:= (7 X 1) + (2 X 0)]
summon[(3 X 0) + ((4 X 1) + (2 X 0)) =:= (4 X 1) + (5 X 0)]
summon[((3 X 1) + (2 X 0)) + (4 X 1) =:= (7 X 1) + (2 X 0)]
summon[((3 X 1) + (2 X 0)) + (4 X 2) =:= (4 X 2) + (3 X 1) + (2 X 0)]
summon[(3 X 1) + ((4 X 2) + (2 X 0)) =:= ((4 X 2) + (3 X 1) + (2 X 0))]
summon[
((5 X 3) + (3 X 1)) + ((4 X 2) + (2 X 0)) =:= (5 X 3) + (4 X 2) + (3 X 1) + (2 X 0)
]
多分あってる。
終わり
dottyの登場によって型レベル演算はscalaコンパイラを熟知したhackyなものではなく、どれだけ型レベルのifを書けるかの勝負になっている感がある。
最後に全コード。
import scala.compiletime.ops.int
import scala.compiletime.ops.any.*
import Polynomial.*
enum Polynomial:
case X[Coef <: Int, N <: Int]()
case Plus[XX <: X[?, ?], YY <: Polynomial]()
type N[XX <: X[?, ?]] = XX match
case ? X n => n
type C[XX <: X[?, ?]] = XX match
case c X ? => c
type +[A <: Polynomial, B <: Polynomial] <: Polynomial = A match
case c X n =>
B match
case cc X nn =>
n == nn match
case true => int.+(c, cc) X n
case false =>
int.>(n, nn) match
case true => A Plus B
case false => B Plus A
case xx Plus yy =>
int.>(n, N[xx]) match
case true => A Plus B
case false =>
n == N[xx] match
case true => (int.+(c, C[xx]) X n) Plus yy
case false => xx Plus (A + yy)
case xx Plus yy =>
B match
case cc X nn =>
N[xx] == nn match
case true => (int.+(C[xx], cc) X nn) Plus yy
case false =>
int.>(N[xx], nn) match
case true => xx Plus (B + yy)
case false => B Plus A
case xxx Plus yyy =>
N[xx] == N[xxx] match
case true => (int.+(C[xx], C[xxx]) X N[xx]) Plus (yy + yyy)
case false =>
int.>(N[xx], N[xxx]) match
case true => xx Plus (yy + B)
case false => xxx Plus (yyy + A)
object Polynomial:
summon[N[1 X 2] =:= 2]
summon[C[1 X 2] =:= 1]
summon[(3 X 1) + (4 X 1) =:= (7 X 1)]
summon[(3 X 1) + (2 X 0) =:= ((3 X 1) Plus (2 X 0))]
summon[(3 X 1) + (4 X 2) =:= (4 X 2) + (3 X 1)]
summon[(3 X 2) + (4 X 1) + (2 X 0) =:= ((3 X 2) Plus ((4 X 1) Plus (2 X 0)))]
summon[(3 X 2) + ((4 X 1) + (2 X 0)) =:= (3 X 2) + (4 X 1) + (2 X 0)]
summon[(3 X 1) + ((4 X 1) + (2 X 0)) =:= (7 X 1) + (2 X 0)]
summon[(3 X 0) + ((4 X 1) + (2 X 0)) =:= (4 X 1) + (5 X 0)]
summon[((3 X 1) + (2 X 0)) + (4 X 1) =:= (7 X 1) + (2 X 0)]
summon[((3 X 1) + (2 X 0)) + (4 X 2) =:= (4 X 2) + (3 X 1) + (2 X 0)]
summon[(3 X 1) + ((4 X 2) + (2 X 0)) =:= ((4 X 2) + (3 X 1) + (2 X 0))]
summon[
((5 X 3) + (3 X 1)) + ((4 X 2) + (2 X 0)) =:= (5 X 3) + (4 X 2) + (3 X 1) + (2 X 0)
]