ASTを使ってSQLAlchemyのテーブル定義コードを生成する

はじめに

こんにちは、 whosaysni です。 MonotaROでは開発インフラを整備するチームで働いています(この記事書いてる間にデータマーケティング部門に転属になりました)。

最近のちょっとしたツールづくりの話題をお伝えしたいと思います。

簡単にやったことをまとめると

  • SQLAlchemyのテーブル定義を自動生成するプログラムを書いた
  • プログラムを生成するのに抽象構文木を使った
  • コードを整形するのに yapf を使った

です。では始めましょう。

課題: Table() 呼び出しコードを生成する

SQLAlchemy といえば、様々な種類の DB API に対応していて、Python コードから SQL の構文オブジェクト (SQL式) を操作して SQL を生成・実行でき、 ORM (Object-Relational Mapper) のインタフェースも備えている、データベース操作のアーミーナイフのような存在です。 MonotaRO でも、SQLAlchemy の利用が少しづつ始まっています。

そんな中、あるプロジェクトで、既存のデータベーススキーマに対応するテーブル定義(ORMを使わない方)が一揃い欲しくなりました。

SQLAlchemy では、テーブル定義は sqlalchemy.schema.Table() で生成したテーブルオブジェクトで表します:

from sqlalchemy import Column, MetaData, Table

t = Table(
    'UserMaster', MetaData(), 
    Column('id', INTEGER, primary_key=True),
    ...  # 全カラム定義していく
    )
    

ですが、テーブルがたくさんある場合、これを一つ一つ書いていくのは一苦労。 いっぽう、SQLAlchemy には、既存のデータベーススキーマからテーブル定義を抽出する機能もあります:

from sqlalchemy import MetaData, Table, create_engine

e = create_engine('mysql://username:passwd@dbhost/schema_name')
# autoload_with に指定したエンジンからテーブル定義を取得する
t = Table('UserMaster', MetaData(), autoload_with=e)

これはこれで便利なのですが、Table() を呼び出した瞬間、テーブル構造を取得するためのクエリが実行されて地味にオーバヘッドが乗ったり、テーブル構造がソフトウェア側に明に定義されていないという不安があって避けたいという状況でした。

そこで、既存のテーブルからあらかじめスキーマ情報を抽出して、カラム情報を定義した Table() 呼び出しのコードを自動生成することにしました。

Python の抽象構文木オブジェクトからコードを生成する

実のところ、社内にはすでに似たような目的のために作られたツールがあったので転用してみたのですが、このツールは以下のような感じでテンプレートエンジンを使って条件分岐でコードを生成していました:

t = Table({{ tablename }}, MetaData(),
    {% for column in columns %}
    Column({{ column.name }}, {{ column.datatype }}{% if column.has_args %}(...){% endif %}),
    {% endfor %}
    mysql_engine='{{ db_engine_name }}',
    ...
    )

しかしこれではテンプレートのメンテナンスが結構大変です。 読みやすいコードを出力させるのも(スペースや改行の入り方を制御するパズルになってしまって)難ありだし、「関数」や「引数」といった単位でモジュラーに書きづらいです。

方針を変えて、Pythonの抽象構文木からコードを生成することにしました。

Pythonに限らず、プログラミング言語の処理系は、プログラムの書かれたソースコードを解析して、書かれている命令を実行していきます。そして、言語ごとに命令の書き方、すなわち構文が決まっています。たとえば、Pythonの関数呼び出しなら:

関数名 ( 引数1, 引数2, ... , キーワード1=値1, ...)

のように書きますが、これは、構文としては

関数
   引数列
       引数1
       引数2
       ...
   キーワード引数列
       キーワード引数1
         キーワード1
         値1
       ...

のような木構造で表されます。関数の中に別の関数呼び出しがあれば、入れ子の深い構造になっていきます。この構造を構文木といい、 Python では ast モジュールを使って、プログラム構造を解析し、プログラムの実行に関係のない情報(空白とか区切り文字とか)を取り除いた抽象構文木 (Abstract Syntax Tree, AST) データを得られます。

抽象構文木はプログラムの本質的な部分に対応しているので、原理的には、抽象構文木から(見栄えのための改行や空白以外は)プログラムを完全に復元することができるはずです。

Pythonのことだからきっと同じことを考えて実装した人がいるはず・・・と思って探したら、ありました。 astor (https://pypi.python.org/pypi/astor/) です。

astor は、ソースコードからASTへの変換、そしてその逆ができます。

>>> from ast import alias, ImportFrom, Module
>>> import_ast = ImportFrom(module='future', names=[alias('antigravity', None)], level=0)
>>> module_ast = Module(body=[import_ast])
>>> import astor
>>> astor.to_source(module_ast)
'from future import antigravity\n'

これで、ASTさえ作ればコードを生成できそうな気がしてきました。

モジュールの AST 構造を設計する

今回作りたいツールでは、1つのテーブル定義を1つのモジュールに出力する必要がありました。モジュールファイルは、以下のような構成になるはずです:

from sqlalchemy import Column, MetaData, Table
from sqlalchemy.dialects.mysql import ...  # モジュールが依存しているもの諸々

table = Table(
    'テーブル名', MetaData(),
    Column('カラム名1', ...),
    Column('カラム名2', ...),  # テーブルで定義されているカラム
    ...,
    PrimaryKeyConstraint(...),  # テーブルにかかっている制約
    ...,
    mysql_charset='utf-8',  # テーブルオプション
    )

これを AST でざっくり表すと、こんな構造です:

モジュール
    import 文
    import 文
    ...
    代入文
        代入先:table
        代入元:関数呼び出し・Table(...)
            Table()の引数1:テーブル名の文字列
            Table()の引数2:関数呼び出し・MetaData()
            Table()の引数3:関数呼び出し・Column(...)
                Column()の引数1:カラム名の文字列
                Column()の引数2:カラムの型
                ...

なので、やることはこんな感じですね。

  • モジュールのASTを生成する
  • import 文のASTを生成してモジュールに加える
  • 代入文のASTを生成してモジュールに加える

ただし、テーブルを生成して代入する式の中でどんなオブジェクトを使うかが決まらないと、import 文で何を import すべきかは確定しません。したがって、順番はこうなります:

  • ①代入文のASTを生成する。importが必要な名前は覚えておく
  • ②モジュールのASTを生成する
  • ③(①の)import文のASTを生成して加える
  • ④代入文のASTを加える

下のコードは作成した生成器のコードの一部で、 generate() にテーブル名を指定して呼び出すと、 generate_module_node でモジュールの AST を生成し、最終的に astor.to_source() でコードに変換します。

generate_module_node は、さらに代入文や import 文のASTを生成して、モジュールの ASTに加えていきます。このとき、処理中に集めた情報(importする必要のある名前は何か、どのカラムに UNIQUE や PRIMARY KEY 制約がかかっているか、など)を持ちまわるために、Context というオブジェクトを作成して受け渡し、処理の中で持ちまわるようにしています。

(docstringが怪しい英語ですがご愛嬌。当社ではオープンソース化を意識したコードは英語でdocstringやコメントを書き、ビジネスロジック部分の外に出せないコードのコメントは日本語で書くことが多いです)

    def generate(self, table_name, **options):
        """Generate AST node of table definition module.

        This method generates source code of schema definition
        module for given table name. The table structure is
        extracted using sqlalchemy.inspect().

        :param str table_name: table name to generate definition
        :param dict options: extra options for generation process

        :return: source code of schema definition module
        :rtype: str
        """
        if table_name not in self.list_tables():
            raise ValueError('No such table',
                             dict(table_name=table_name))
        context = Context(table_name, **options)
        mod_node = self.generate_module_node(context)
        return astor.to_source(mod_node)

    def generate_module_node(self, context):
        """Generate AST node of schema definition module.

        :param Context context: processing context

        :return: AST module node of the schema definition module
        :rtype: ast.Module
        """
        body_nodes = []
        # process code first, then collect import information
        prefix_nodes = self.generate_prefix_nodes(context)
        table_def_nodes = self.generate_table_def_nodes(context)
        suffix_nodes = self.generate_suffix_nodes(context)
        import_nodes = self.generate_import_nodes(context)
        body_nodes.extend(import_nodes)
        body_nodes.extend(prefix_nodes)
        body_nodes.extend(table_def_nodes)
        body_nodes.extend(suffix_nodes)
        return Module(body=body_nodes)

AST生成の処理の末端では、地道に AST の要素を組み立てるコードを実行していきます。ASTを使っているかぎり、構文エラーになるような出力はほぼ起きないので、目的のプログラム構造を組み立てる作業に専念できます。小さな構文構造ごとに問題を分けて精密に組めるので、機能追加やテストもしやすい。それに、模型作りのようでちょっと楽しい!

    def generate_import_nodes(self, context):
        """Generate import statement nodes.

        :param Context context: processing context

        :return: list of AST import nodes
        :rtype: list[ast.Import|ast.ImportFrom]
        """
        imports = [
            Import(names=[alias(name_, alias_)])
            for name_, alias_ in context.imports
        ]
        from_imports = [
            ImportFrom(
                module=mod_name,
                names=[alias(name_, alias_) for name_, alias_ in names],
                level=0)
            for mod_name, names in context.from_imports]
        return imports + from_imports

生成されたコードを整形する

astor.to_source() は「構文として正しい」プログラムを生成するだけなので、テーブル定義の式は以下のようにぎっちり詰め込まれて出力されてしまいます:

>>> from schema_def_generator import SchemaDefModuleGenerator as ModGen
>>> gen = ModGen('mysql://user:pass@localhost/sakila')
>>> print(gen.generate('film'))
from sqlalchemy.dialects.mysql.types import DECIMAL, SMALLINT, TEXT, TIMESTAMP, TINYINT, VARCHAR, YEAR
from sqlalchemy import Column, Index, MetaData, PrimaryKeyConstraint, Table
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.dialects.mysql.enumerated import ENUM, SET
t = Table('film', MetaData(), Column('film_id', SMALLINT(display_width=5,
    unsigned=True), nullable=False, autoincrement=True, doc=''), Column(
    'title', VARCHAR(length=255), nullable=False, doc=''), Column(
    'description', TEXT(), nullable=True, doc=''), Column('release_year',
    YEAR(display_width=4), nullable=True, doc=''), Column('language_id',
    TINYINT(display_width=3, unsigned=True), nullable=False, doc=''),
    Column('original_language_id', TINYINT(display_width=3, unsigned=True),
    nullable=True, doc=''), Column('rental_duration', TINYINT(display_width
    =3, unsigned=True), nullable=False, server_default='3', doc=''), Column
    ('rental_rate', DECIMAL(precision=4, scale=2), nullable=False,
    server_default='4.99', doc=''), Column('length', SMALLINT(display_width
    =5, unsigned=True), nullable=True, doc=''), Column('replacement_cost',
    DECIMAL(precision=5, scale=2), nullable=False, server_default='19.99',
    doc=''), Column('rating', ENUM('G', 'PG', 'PG-13', 'R', 'NC-17'),
    nullable=True, server_default='G', doc=''), Column('special_features',
    SET(length=17), nullable=True, doc=''), Column('last_update', TIMESTAMP
    (), nullable=False, server_default=TextClause(
    'CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP'), doc=''),
    PrimaryKeyConstraint('film_id'), Index('idx_title', 'title'), Index(
    'idx_fk_language_id', 'language_id'), Index(
    'idx_fk_original_language_id', 'original_language_id'), mysql_engine=
    'InnoDB', mysql_default_charset='utf8')

これでは保守できないので、コードフォーマッタを使って綺麗に整形します。 コードフォーマッタには Google の yapf (https://github.com/google/yapf) を使うことにしました。

>>> from yapf.yapflib.yapf_api import FormatCode
>>> print(FormatCode(gen.generate('film'))[0])
from sqlalchemy.dialects.mysql.types import DECIMAL, SMALLINT, TEXT, TIMESTAMP, TINYINT, VARCHAR, YEAR
from sqlalchemy import Column, Index, MetaData, PrimaryKeyConstraint, Table
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.dialects.mysql.enumerated import ENUM, SET
t = Table(
    'film',
    MetaData(),
    Column(
        'film_id',
        SMALLINT(display_width=5, unsigned=True),
        nullable=False,
        autoincrement=True,
        doc=''),
    Column('title', VARCHAR(length=255), nullable=False, doc=''),
    Column('description', TEXT(), nullable=True, doc=''),
    ...
    mysql_engine='InnoDB',
    mysql_default_charset='utf8')

うん、なかなかいいですね。これなら使えそう。

さいごに

実際のツール開発では、本番系とまったく同じ CREATE TABLE 文を生成できるところまで作りこみ、インデクス情報や制約の出力機能も実装できました。

SQLAlchemy のよくできているところは、SQLで表現できることがらの一つ一つがPythonのオブジェクトとして抽象化されているところですね。その恩恵で、オブジェクトを組み合わせてSQLやスキーマ定義を作ることもできるし、逆にDB上のデータやスキーマ定義も同じオブジェクトの組み合わせで扱えます。

Pythonのライブラリやアプリケーションには、静的なコードを生成する仕組みはあまりポピュラーではないと思います。メタクラスを駆使して動的に構築するのがPython流なのかもしれません。とはいえ、今回のようにASTを使えば、複雑なコード生成も段階的に構築できるし、出力の文法的な正しさを保つ労力も、ASTの制約のおかげで大幅に省けます。