NdArrayを例にPythonのTypeVarTupleの力を確認する

 
0
このエントリーをはてなブックマークに追加
Kazuki Moriyama
Kazuki Moriyama (森山 和樹)

TypeVarTupleとは

簡単に言えば可変長なtype parameter.
複数の型を一つの型変数で表現できる。

例えば下の例ではtupleの長さによって戻り値の型が変わっている。

T = TypeVar("T")
Ts = TypeVarTuple("Ts")

def move_first_element_to_last(tup: tuple[T, *Ts]) -> tuple[*Ts, T]:
    return (*tup[1:], tup[0])

move_first_element_to_last(tup=(1,)) # tuple[int]
move_first_element_to_last(tup=(1, 'spam')) # tuple[str, int]
move_first_element_to_last(tup=(1, 'spam', 3.0)) # tuple[str, float, int]
move_first_element_to_last(tup=()) # type check fails

TypeVarTupleのモチベーション

大きなユースケースの一つはnumpyなどの数値計算ライブラリで用いるndarray系のデータ型にうまい具合に型を付けられるようになること。
つまりndarrayは多くの場合はネストした配列の様なデータ構造になり、それらのデータ型が型として表現できると嬉しいね、ということ。

例えば下の様な型で各配列のディメンションを表現できると嬉しい。

from typing import Literal as L

# [[1.0, ...640個], ...480個]で中身がfloatな配列
NdArray[float, L[480], L[640]]
# [[[1.0, ...100個], ...400個], ...500個]で中身がintな配列
NdArray[int, L[500], L[400], L[100]]

上の例ではディメンションを表現する部分のジェネリクス(L[500]とか)の数が変わっているが、この様に可変長にするためにTypeVarTupleが使える。

実際にNdArrayを作ってみる

モチベーションとユースケースはなんとなくわかったので、実際にNdArrayを作ってみることで実装に即した型がどこまで厳密に定義できるかを見ていく。
使用するライブラリなどは以下。

  • python3.12
  • 内部データ型としてnumpy
  • 型チェックはpyright(~pylance)

定義

おおよそクラス自体の見た目は以下の様になる。

from typing import TypeVarTuple, Generic, TypeVar

DType = TypeVar("DType")
Shape = TypeVarTuple("Shape")

class NdArray(Generic[DType, *Shape]): ...

これでモチベーションのとこで例示したような型の定義が使えるようになる。

NdArray[float, L[480], L[640]]
NdArray[int, L[500], L[400], L[100]]

初期化処理

import numpy as np

DType = TypeVar("DType")
Shape = TypeVarTuple("Shape")

Ts = TypeVarTuple("Ts")

Data = list[DType] | list["Data[DType]"]

class NdArray(Generic[DType, *Shape]):
    def __init__(self, shape: tuple[*Shape], data: Data[DType]):
        for s in shape:
            if not isinstance(s, int):
                raise TypeError("Shape must be a tuple of integers")

        if not self._check_shape(shape, data):
            raise ValueError("Data is not in the shape specified by the shape argument")

        self.shape = shape
        self.data = np.array(data)

    def _check_shape(self, shape: tuple[*Ts], data: Data[DType]) -> bool:
        """Check if data is in the shape specified by the shape argument."""
        if not isinstance(data, list):
            return False

        match len(shape):
            case 1:
                return len(data) == shape[0]
            case _:
                return len(data) == shape[0] and all(
                    self._check_shape(shape[1:], cast(Data[DType], d)) for d in data
                )

初期化のときには一応指定されたshapeと実際の型が等しいかチェックするのが望ましい。
上の様なコードを書けば一応初期化時の検査が行われ、不正なデータの場合には例外が投げられる。

NdArray((2, 2), [1, 2]) # ValueError
NdArray((2, 2), [[1, 2], [3, 4], [5, 6]]) # ValueError
NdArray((2, 2), [[1, 2], [3, 4]]) # Valid

dataの型がData[DType]となっているがこれは再帰的にlistのネストを許したデータのみ許可したいから。

Data = list[DType] | list["Data[DType]"]

例えばData[int]は以下の様な挙動になる。

# valid
a: Data[int] = [1]
b: Data[int] = [[1, 2], [3, 4]]

# error: Expression of type "Literal[1]" cannot be assigned to declared type "Data[int]"
#   Type "Literal[1]" cannot be assigned to type "Data[int]"
#   "Literal[1]" is incompatible with "list[int]"
#   "Literal[1]" is incompatible with "list[Data]" (reportGeneralTypeIssues)
c: Data[int] = 1

これでそもそもlistとして不正なものを型チェックで防ぎつつ、厳密な検証はruntimeで行うことで正しいNdArrayの構築ができる。

また作成されたNdArrayの型は推論がうまく効かないので明示的に書いてあげるのが良い。

arr1: NdArray[int, L[2], L[2]] = NdArray((2, 2), [[1, 2], [3, 4]])
# or
arr1 = NdArray[int, L[2], L[2]]((2, 2), [[1, 2], [3, 4]])

arange

numpyはndarrayを適当な長さで生成するarange関数が存在する。
こちらも長さをリテラル型で取れるので、型がついた状態でNdArrayを生成できる。

Len = TypeVar("Len", bound=int)

class NdArray(Generic[DType, *Shape]):
    ...
    @staticmethod
    def arange(l: Len) -> "NdArray[int, Len]":
        return NdArray((l,), list(np.arange(l)))
arr: NdArray[int, L[5]] = NdArray.arange(5)

assert arr.array_equal(NdArray[int, L[5]]((5,), [0, 1, 2, 3, 4]))

# error: Argument of type "Literal[6]" cannot be assigned to parameter "l" of type "Len@arange" in function "arange"
#   "Literal[6]" cannot be assigned to type "Literal[5]" (reportGeneralTypeIssues)
arr: NdArray[int, L[5]] = NdArray.arange(6)

等値比較

numpy.ndarrayは全要素が等しいかを比較するには==(__eq__)での比較ではなく、numpy.array_equalによる比較を行う。
同じ様なメソッドをNdArrayに生やしてみる。

class NdArray(Generic[DType, *Shape]):
    ...
    def array_equal(self, other: "NdArray[DType, *Shape]") -> bool:
        return np.array_equal(self.data, other.data)

shapeの比較は型の検査によって行われるのでruntimeの処理としては記述していない。(pythonはtype hintを使用していないユーザの使用も想定する必要があると思うので、本来であればruntimeの検査も記述する必要があると思われるが。)

これを使えばすべての値を見た比較ができる。

arr1 = NdArray[int, L[2], L[2]]((2, 2), [[1, 2], [3, 4]])
arr2 = NdArray[int, L[2], L[2]]((2, 2), [[1, 2], [3, 4]])
arr3 = NdArray[int, L[2], L[2]]((2, 2), [[5, 6], [7, 8]])

assert arr1.array_equal(arr2)
assert not arr1.array_equal(arr3)

重要なのが型が異なるNdArrayはtype check(!= runtime)時に型エラーになる。

arr1 = NdArray[int, L[2], L[2]]((2, 2), [[1, 2], [3, 4]])
arr2 = NdArray[int, L[2], L[3]]((2, 3), [[5, 6, 7], [8, 9, 10]])             

#  error: Argument of type "NdArray[int, Literal[2], Literal[3]]" cannot be assigned to parameter "other" of type "NdArray[int, Literal[2], Literal[2]]" in function "array_equal"
#   "NdArray[int, Literal[2], Literal[3]]" is incompatible with "NdArray[int, Literal[2], Literal[2]]"
#   Type parameter "Shape@NdArray" is invariant, but "*tuple[Literal[2], Literal[3]]" is not the same as "*tuple[Literal[2], Literal[2]]" (reportGeneralTypeIssues)
arr1.array_equal(arr2)  # type: ignore[reportGeneralTypeIssues]

のでプログラムを走らせる前に静的にそもそも比較ができるのかどうかが型で検査できる。

四則演算

numpyのndarrayは+などの四則演算をサポートしている。
それらはブロードキャストによる異なる形の配列同士の計算を許容しているが、ここで作るNdArrayに対しては同じ形をした配列同士の計算しか許さないことにする。

たとえば+(__add__)の実装は以下の様になる。

class NdArray(Generic[DType, *Shape]):
    ...
    def __add__(self, other: "NdArray[DType, *Shape]") -> "NdArray[DType, *Shape]":
        return NdArray(self.shape, (self.data + other.data).tolist())

かなり雑な実装なのでパフォーマンスなどは考えていない。

これを使えば型レベルでaddできるNdArrayの検査ができるようになる。

arr1 = NdArray[int, L[2], L[2]]((2, 2), [[1, 2], [3, 4]])
arr2 = NdArray[int, L[2], L[2]]((2, 2), [[5, 6], [7, 8]])

arr3 = arr1 + arr2

assert arr3.array_equal(NdArray((2, 2), [[6, 8], [10, 12]]))

arr1 = NdArray[int, L[2], L[2]]((2, 2), [[1, 2], [3, 4]])
arr2 = NdArray[int, L[2], L[3]]((2, 3), [[5, 6, 7], [8, 9, 10]])

# error: Operator "+" not supported for types "NdArray[int, Literal[2], Literal[2]]" and "NdArray[int, Literal[2], Literal[3]]" (reportGeneralTypeIssues)
arr1 + arr2

他の四則演算についても同様に実装すれば型レベルの検査ができるようになる。

def __sub__(self, other: "NdArray[DType, *Shape]") -> "NdArray[DType, *Shape]":
    return NdArray(self.shape, (self.data - other.data).tolist())

def __mul__(self, other: "NdArray[DType, *Shape]") -> "NdArray[DType, *Shape]":
    return NdArray(self.shape, (self.data * other.data).tolist())

# floatベタ書きはかなり微妙な気がする
def __truediv__(self, other: "NdArray[DType, *Shape]") -> "NdArray[float, *Shape]":
    return NdArray(self.shape, (self.data / other.data).tolist())

in

numpyのndarrayにはinが使えて、これである要素が対象のndarrayに含まれるかどうかを検査できる。
inの実装はTypeVarTupleとは関係ないのだが、DTypeを型に含めていることによってこれも静的に型で検査できる。

class NdArray(Generic[DType, *Shape]):
    ...
    def __contains__(self, item: DType) -> bool:
        return item in self.data
arr = NdArray[int, L[2], L[2]]((2, 2), [[1, 2], [3, 4]])

assert 1 in arr
assert 5 not in arr

# Operator "in" not supported for types "Literal['1']" and "NdArray[int, Literal[2], Literal[2]]" (reportGeneralTypeIssues)
"1" in arr  # type: ignore[reportGeneralTypeIssues]

reshape

配列の次元を操作する関数。
NdArrayでも次元をtupleで指定してあげれば型がついた状態で変換が行える。
指定するshapeの型変数はNdArrayで使用しているものとは別のものを使用する必要がある。

class NdArray(Generic[DType, *Shape]):
    ...
    def reshape(self, shape: tuple[*Ts]) -> "NdArray[DType, *Ts]":
        return NdArray(shape, self.data.reshape(cast(SupportsIndex, shape)).tolist())
arr: NdArray[int, L[2], L[3]] = NdArray.arange(6).reshape((2, 3))

assert arr.array_equal(NdArray[int, L[2], L[3]]((2, 3), [[0, 1, 2], [3, 4, 5]]))

# error: Argument of type "tuple[Literal[2], Literal[3]]" cannot be assigned to parameter "shape" of type "tuple[*Ts@reshape]" in function "reshape"
#   "Literal[3]" cannot be assigned to type "Literal[2]" (reportGeneralTypeIssues)
arr: NdArray[int, L[2], L[2]] = NdArray.arange(6).reshape((2, 3))  # type: ignore[reportGeneralTypeIssues]

型がうまくつけれられない操作

ここまで型がつけられる操作群を見てきたが、それ以外の殆どの操作には型がうまくつけられない。

例えばtransposeはShapeを逆順にする必要があるが、それが型の上ではうまく表現できない。

Shape = L[2], L[3]

のときtransposeすると

Shape = L[3], L[2]

になってほしいが表現できない

またdeleteなども次元の減少を型の上で表現できないので難しい。

つまり(少なくとも以上の様なやり方では)numpyの操作を完全に型の上で記述するのは難しい。

所感

TypeVarTupleはある程度便利だが、それ自身の制限と他のtype hintの機能が貧弱なので柔軟性に欠ける。
具体的には、

  • TypeVarTupleの他の型の項への分解性能が低いので、型を再構築できない
  • リテラル型に対して型の上の+などが定まっていないので次元の増減をうまく表現できない。

とか。

とはいっても機能が増えていることは確かなで、現状の機能でのナイスな型付け方法は誰かがすでにやってそうな気はするので暇なときにリサーチしたい。

全コード

from typing import TypeVarTuple, Generic, TypeVar, cast, SupportsIndex

import numpy as np

DType = TypeVar("DType")
Shape = TypeVarTuple("Shape")
Ts = TypeVarTuple("Ts")
Len = TypeVar("Len", bound=int)

Data = list[DType] | list["Data[DType]"]

class NdArray(Generic[DType, *Shape]):
    def __init__(self, shape: tuple[*Shape], data: Data[DType]):
        for s in shape:
            if not isinstance(s, int):
                raise TypeError("Shape must be a tuple of integers")

        if not self._check_shape(shape, data):
            raise ValueError("Data is not in the shape specified by the shape argument")

        self.shape = shape
        self.data = np.array(data)

    def _check_shape(self, shape: tuple[*Ts], data: Data[DType]) -> bool:
        """Check if data is in the shape specified by the shape argument."""
        if not isinstance(data, list):
            return False

        match len(shape):
            case 1:
                return len(data) == shape[0]
            case _:
                return len(data) == shape[0] and all(
                    self._check_shape(shape[1:], cast(Data[DType], d)) for d in data
                )

    def reshape(self, shape: tuple[*Ts]) -> "NdArray[DType, *Ts]":
        return NdArray(shape, self.data.reshape(cast(SupportsIndex, shape)).tolist())

    def __add__(self, other: "NdArray[DType, *Shape]") -> "NdArray[DType, *Shape]":
        return NdArray(self.shape, (self.data + other.data).tolist())

    def __sub__(self, other: "NdArray[DType, *Shape]") -> "NdArray[DType, *Shape]":
        return NdArray(self.shape, (self.data - other.data).tolist())

    def __mul__(self, other: "NdArray[DType, *Shape]") -> "NdArray[DType, *Shape]":
        return NdArray(self.shape, (self.data * other.data).tolist())

    # floatベタ書きはかなり微妙な気がする
    def __truediv__(self, other: "NdArray[DType, *Shape]") -> "NdArray[float, *Shape]":
        return NdArray(self.shape, (self.data / other.data).tolist())

    def __contains__(self, item: DType) -> bool:
        return item in self.data

    def array_equal(self, other: "NdArray[DType, *Shape]") -> bool:
        return np.array_equal(self.data, other.data)

    @staticmethod
    def arange(l: Len) -> "NdArray[int, Len]":
        return NdArray((l,), list(np.arange(l)))

参考

info-outline

お知らせ

K.DEVは株式会社KDOTにより運営されています。記事の内容や会社でのITに関わる一般的なご相談に専門の社員がお答えしております。ぜひお気軽にご連絡ください。