# Source code for qualia_plugin_snn.experimenttracking.QualiaDatabase

"""Provide the QualiaDatabase class extension for Spiking Neural Network tracking."""

from __future__ import annotations

import logging
import sqlite3
import sys
from typing import TYPE_CHECKING as NATIVE_TYPE_CHECKING
from typing import Any, Final

from qualia_core.experimenttracking.QualiaDatabase import QualiaDatabase as QualiaDatabaseQualiaCore
from qualia_core.typing import TYPE_CHECKING

from qualia_plugin_snn.learningmodel.pytorch.SNN import SNN

if TYPE_CHECKING:
    from qualia_core.qualia import TrainResult

if NATIVE_TYPE_CHECKING:
    from qualia_plugin_snn.postprocessing.OperationCounter import OperationMetrics

if sys.version_info >= (3, 12):
    from typing import override
else:
    from typing_extensions import override

logger = logging.getLogger(__name__)


class QualiaDatabase(QualiaDatabaseQualiaCore):
    """Extend :class:`qualia_core.experimenttracking.QualiaDatabase.QualiaDatabase` with support for Spiking Neural Networks.

    Two extra tables are managed on top of the core schema, both keyed by ``model_id`` (foreign key to ``models(id)``):

    * ``models_snn``: whether the model is an SNN and its number of timesteps;
    * ``models_operationcounter``: the ``Total`` entry of the OperationCounter postprocessing results.

    The schema extension version is stored in the ``plugins`` table under the name ``qualia_plugin_snn``.
    """

    # SQL text is assembled from adjacent string literals so that each column sits on its own source line.

    # Latest schema extension to create fresh tables
    __sql_schema_snn: Final[str] = (
        ' CREATE TABLE IF NOT EXISTS models_snn ('
        ' id INTEGER PRIMARY KEY AUTOINCREMENT,'
        ' model_id INTEGER,'
        ' is_snn INTEGER,'
        ' timesteps INTEGER,'
        ' UNIQUE(model_id),'
        ' FOREIGN KEY(model_id) REFERENCES models(id)'
        ' );'
        ' CREATE TABLE IF NOT EXISTS models_operationcounter ('
        ' id INTEGER PRIMARY KEY AUTOINCREMENT,'
        ' model_id INTEGER,'
        ' syn_acc REAL,'
        ' syn_mac REAL,'
        ' addr_acc REAL,'
        ' addr_mac REAL,'
        ' total_acc REAL,'
        ' total_mac REAL,'
        ' mem_write REAL,'
        ' mem_read REAL,'
        ' input_spikerate REAL,'
        ' output_spikerate REAL,'
        ' input_count REAL,'
        ' output_count REAL,'
        ' input_is_binary INTEGER,'
        ' output_is_binary INTEGER,'
        ' UNIQUE(model_id),'
        ' FOREIGN KEY(model_id) REFERENCES models(id)'
        ' );'
        ' '
    )

    # Incremental schema extension upgrades.
    # Entry ``i`` upgrades the schema from version ``i`` to version ``i + 1``; the latest version is the list length.
    __sql_schema_upgrades_snn: Final[list[str]] = [
        ' ALTER TABLE models_snn ADD COLUMN is_snn INTEGER; ',
        (
            ' CREATE TABLE IF NOT EXISTS models_operationcounter ('
            ' id INTEGER PRIMARY KEY AUTOINCREMENT,'
            ' model_id INTEGER,'
            ' syn_acc REAL,'
            ' syn_mac REAL,'
            ' addr_acc REAL,'
            ' addr_mac REAL,'
            ' mem_write REAL,'
            ' mem_read REAL,'
            ' input_spikerate REAL,'
            ' output_spikerate REAL,'
            ' input_count REAL,'
            ' output_count REAL,'
            ' input_is_binary INTEGER,'
            ' output_is_binary INTEGER,'
            ' UNIQUE(model_id),'
            ' FOREIGN KEY(model_id) REFERENCES models(id)'
            ' );'
            ' '
        ),
        ' ALTER TABLE models_operationcounter ADD COLUMN total_acc REAL; ',
        ' ALTER TABLE models_operationcounter ADD COLUMN total_mac REAL; ',
    ]

    # Named-parameter queries used by this extension
    __queries_snn: Final[dict[str, str]] = {
        'get_schema_version_snn': "SELECT schema_version FROM plugins WHERE name = 'qualia_plugin_snn'",
        'set_schema_version_snn': "INSERT OR REPLACE INTO plugins(name, schema_version) VALUES ('qualia_plugin_snn', :version)",
        'insert_model_snn': (
            'INSERT OR REPLACE INTO models_snn(model_id, timesteps, is_snn)'
            ' VALUES(:model_id, :timesteps, :is_snn)'
        ),
        'insert_model_operationcounter': (
            'INSERT OR REPLACE INTO models_operationcounter('
            ' model_id, syn_acc, syn_mac, addr_acc, addr_mac, total_acc, total_mac, mem_write, mem_read,'
            ' input_spikerate, output_spikerate, input_count, output_count, input_is_binary, output_is_binary'
            ' ) VALUES ('
            ' :model_id, :syn_acc, :syn_mac, :addr_acc, :addr_mac, :total_acc, :total_mac, :mem_write, :mem_read,'
            ' :input_spikerate, :output_spikerate, :input_count, :output_count, :input_is_binary, :output_is_binary'
            ' )'
        ),
        'get_model_snn': 'SELECT * from models_snn WHERE model_id = :model_id',
        'get_model_operationcounter': 'SELECT * from models_operationcounter WHERE model_id = :model_id',
    }

    def __set_schema_version_snn(self, cur: sqlite3.Cursor, version: int) -> None:
        """Store the SNN schema extension version in the ``plugins`` table.

        The caller is responsible for committing.

        :param cur: cursor to execute the query with
        :param version: schema extension version to record
        """
        _ = cur.execute(self.__queries_snn['set_schema_version_snn'], {'version': version})

    def __get_schema_version_snn(self, cur: sqlite3.Cursor) -> int | None:
        """Read the SNN schema extension version from the ``plugins`` table.

        :param cur: cursor to execute the query with
        :return: recorded version, or ``None`` if there is no ``qualia_plugin_snn`` row
        """
        res = cur.execute(self.__queries_snn['get_schema_version_snn']).fetchone()
        return res[0] if res is not None else None

    def __upgrade_database_schema_snn(self, con: sqlite3.Connection, cur: sqlite3.Cursor) -> None:
        """Create the SNN extension tables or upgrade them to the latest schema version.

        When no version is recorded, the latest schema is created directly. Otherwise, pending incremental
        upgrades are applied in order, each one followed by a version bump and a commit.
        The first upgrade that fails is rolled back and logged, and the remaining upgrades are not attempted.

        :param con: database connection, used to commit or roll back
        :param cur: cursor to execute the queries with
        """
        current_version = self.__get_schema_version_snn(cur)
        latest_version = self.__sql_schema_version_snn
        logger.info('Current database SNN schema extension version: %s, latest schema version: %d',
                    current_version,
                    latest_version)

        # Initialize schema extension with fresh tables
        if current_version is None:
            # Begin transaction to only update version number if schema upgrade succeeded
            # NOTE(review): sqlite3's Cursor.executescript() is documented to COMMIT a pending transaction before
            # running the script (legacy transaction control), so this path may not actually be atomic. Confirm.
            _ = cur.execute('BEGIN')
            _ = cur.executescript(self.__sql_schema_snn)
            self.__set_schema_version_snn(cur, latest_version)
            con.commit()
            return

        # Upgrade existing tables
        for i, sql_schema_upgrade in enumerate(self.__sql_schema_upgrades_snn[current_version:latest_version]):
            new_version = current_version + i + 1
            logger.info('Upgrading database schema to version %d', new_version)
            try:
                _ = cur.execute('BEGIN')  # Begin transaction to only update version number if schema upgrade succeeded
                _ = cur.execute(sql_schema_upgrade)
                self.__set_schema_version_snn(cur, new_version)
                con.commit()
            except sqlite3.Error:
                con.rollback()
                logger.exception('Could not upgrade database schema to version %d', new_version)
                # Stop here: applying later upgrades would record a version past the failed step,
                # which would then never be retried.
                break

    @override
    def _upgrade_database_schema(self, con: sqlite3.Connection, cur: sqlite3.Cursor) -> None:
        """Upgrade the core schema through the parent class, then the SNN schema extension.

        :param con: database connection
        :param cur: database cursor
        """
        super()._upgrade_database_schema(con, cur)
        self.__upgrade_database_schema_snn(con, cur)

    def __get_model_snn(self, cur: sqlite3.Cursor, model_id: int) -> dict[str, Any] | None:
        """Fetch the ``models_snn`` row of a model.

        :param cur: cursor to execute the query with
        :param model_id: id of the model in the ``models`` table
        :return: the row as produced by the cursor (presumably a mapping-like row, verify the connection's
            ``row_factory``), or ``None`` if there is no record
        """
        return cur.execute(self.__queries_snn['get_model_snn'], {'model_id': model_id}).fetchone()

    def __get_model_operationcounter(self, cur: sqlite3.Cursor, model_id: int) -> dict[str, Any] | None:
        """Fetch the ``models_operationcounter`` row of a model.

        :param cur: cursor to execute the query with
        :param model_id: id of the model in the ``models`` table
        :return: the row as produced by the cursor (presumably a mapping-like row, verify the connection's
            ``row_factory``), or ``None`` if there is no record
        """
        return cur.execute(self.__queries_snn['get_model_operationcounter'], {'model_id': model_id}).fetchone()

    @override
    def log_trainresult(self, trainresult: TrainResult) -> int | None:
        """Log a train result through the parent class, then record the SNN metadata of its model.

        The model is flagged as an SNN when it is an instance of
        :class:`qualia_plugin_snn.learningmodel.pytorch.SNN.SNN` or exposes a truthy ``is_snn`` attribute.
        ``timesteps`` is read from the model and defaults to 1 when the attribute is missing.

        :param trainresult: train result to record
        :return: id returned by the parent class, or ``None`` if the database is not initialized
            or the parent class did not return a model id
        """
        model_id = super().log_trainresult(trainresult)

        if not self._con or not self._cur:
            logger.error('Database not initialized')
            return None

        if model_id is None:
            # Without a model id the row could never be looked up again, and since UNIQUE(model_id)
            # does not deduplicate NULLs such orphan rows would accumulate.
            logger.error('Model was not recorded, cannot record SNN metadata')
            return None

        model = trainresult.model
        snn_metadata = {
            'model_id': model_id,
            'is_snn': 1 if isinstance(model, SNN) or getattr(model, 'is_snn', False) else 0,
            'timesteps': getattr(model, 'timesteps', 1),
        }
        _ = self._cur.execute(self.__queries_snn['insert_model_snn'], snn_metadata)
        self._con.commit()

        return model_id

    def log_operationcounter(self, model_hash: str, oms: list[OperationMetrics]) -> None:
        """Record :class:`qualia_plugin_snn.postprocessing.OperationCounter.OperationCounter` Total result in database.

        Only the entry named ``'Total'`` is stored. An error is logged and nothing is written if the database
        is not initialized, the model hash is unknown or no ``'Total'`` entry is present.

        :param model_hash: hash of model to associate database record to
        :param oms: :class:`qualia_plugin_snn.postprocessing.OperationCounter.OperationCounter` results
        """
        if not self._con or not self._cur:
            logger.error('Database not initialized')
            return

        model_id = self._lookup_model_hash(self._cur, model_hash)
        if model_id is None:
            logger.error('Model hash %s not found', model_hash)
            return

        om_total = next((om for om in oms if om.name == 'Total'), None)
        if not om_total:
            logger.error('Could not find Total in OperationMetrics')
            return

        operationcounter_data = {'model_id': model_id, **om_total.asdict()}
        _ = self._cur.execute(self.__queries_snn['insert_model_operationcounter'], operationcounter_data)
        self._con.commit()

    def __print_model_operationcounter(self, operationcounter: dict[str, Any]) -> str:
        """Format a ``models_operationcounter`` record as an aligned, human-readable listing.

        The ``id`` and ``model_id`` columns are left out. The input mapping is not modified.

        :param operationcounter: column name to value mapping of the record
        :return: multi-line string starting with an ``Operation counter:`` header
        """
        fields = {k: v for k, v in operationcounter.items() if k not in ('id', 'model_id')}
        # Pad values so that they line up after the longest displayed name
        max_name_length = max((len(k) for k in fields), default=0)
        lines = [f' {k}: {" " * (max_name_length - len(k))}{v}\n' for k, v in fields.items()]
        return 'Operation counter:\n' + ''.join(lines)

    @override
    def _print_model(self, model: dict[str, Any]) -> str:
        """Format a model record, appending SNN metadata and operation counter results when available.

        :param model: model record, its ``id`` entry is used to look up the extension tables
        :return: text from the parent class followed by the SNN-specific lines,
            or an empty string if the database is not initialized
        """
        s = super()._print_model(model)

        if not self._cur:
            logger.error('Database not initialized')
            return ''

        model_id = model['id']

        model_snn = self.__get_model_snn(self._cur, model_id=model_id)
        if model_snn:
            is_snn = bool(model_snn['is_snn'])
            s += f'SNN: {is_snn}\n'
            s += f'Timesteps: {model_snn["timesteps"]}\n'

        model_operationcounter = self.__get_model_operationcounter(self._cur, model_id=model_id)
        if model_operationcounter:
            # Hand over a plain dict copy of the row
            s += self.__print_model_operationcounter(dict(model_operationcounter))

        return s

    @property
    def __sql_schema_version_snn(self) -> int:
        """Latest SNN schema extension version, i.e., the number of incremental upgrades."""
        return len(self.__sql_schema_upgrades_snn)