Source code for abce.db
# Copyright 2012 Davoud Taghawi-Nejad
#
# Module Author: Davoud Taghawi-Nejad
#
# ABCE is open-source software. If you are using ABCE for your research you
# are requested to quote the use of this software.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License and quotation of the
# author. You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import multiprocessing
from collections import defaultdict
import dataset
from .online_variance import OnlineVariance
from .postprocess import to_csv
class Database(multiprocessing.Process):
    """Separate process that receives data from in_sok and saves it into a
    database.

    Messages are read from ``in_sok`` with ``in_sok.get()`` until the string
    ``"close"`` arrives, or until ``get()`` raises ``KeyboardInterrupt`` or
    ``EOFError``. The following message shapes are understood:

    * ``('snapshot_agg', round, group, data)`` -- every ``key: value`` pair
      of ``data`` is folded into the running aggregate of ``group``.
    * ``('trade_log', trades, round)`` -- ``trades`` maps
      ``(good, seller, buyer, price)`` to a quantity.
    * ``('log', group, name, round, data, subround_or_serial)`` -- ``data``
      becomes one row of the table
      ``'panel___<group>___<subround_or_serial>'``.
    * ``"close"`` -- stop reading and write everything that is buffered.
    * anything else -- ``(method_name, args, kwargs)``, forwarded to the
      plugin instance.

    Args:
        directory: passed to ``to_csv`` together with the database once all
            data is written; ``None`` skips that step.
        in_sok: queue-like object; only its ``get()`` method is used.
        trade_log: truthy to create the ``'trade___trade'`` table and store
            trade messages in it.
        plugin: optional callable, invoked in :meth:`run` as
            ``plugin(*pluginargs)``; the result handles unknown messages.
        pluginargs: positional arguments for ``plugin``; ``None`` (the
            default) means no arguments.
    """

    # Number of buffered rows after which a buffer is written to its table.
    _BATCH_SIZE = 1000

    def __init__(self, directory, in_sok, trade_log, plugin=None,
                 pluginargs=None):
        multiprocessing.Process.__init__(self)
        self.directory = directory
        self.panels = {}
        self.in_sok = in_sok
        self.data = {}
        self.trade_log = trade_log
        self.round = 0
        self.plugin = plugin
        # A fresh list per instance; a literal ``[]`` default would be one
        # list shared by every Database created without pluginargs.
        self.pluginargs = [] if pluginargs is None else pluginargs

    def run(self):
        """Consume messages from ``in_sok`` and write them to the database.

        Rows are buffered per table and inserted in batches; whatever is
        still buffered when the loop ends is inserted afterwards, followed
        by the pending aggregates. Finally the database is committed, the
        plugin's ``close()`` is called if it has one, and, when
        ``self.directory`` is not ``None``, the database is handed to
        ``to_csv``.

        Raises:
            AttributeError: if a message is none of the built-in kinds and
                the plugin has no attribute named ``msg[0]`` (this includes
                the case where no plugin was configured).
        """
        self.aggregation = defaultdict(lambda: defaultdict(OnlineVariance))
        if self.plugin is not None:
            self.plugin = self.plugin(*self.pluginargs)

        # NOTE(review): 'sqlite://' without a path is presumably an in-memory
        # SQLite database (SQLAlchemy URL convention) -- confirm.
        self.dataset_db = dataset.connect('sqlite://')
        self.dataset_db.query('PRAGMA synchronous=OFF')
        # self.dataset_db.query('PRAGMA journal_mode=OFF')
        self.dataset_db.query('PRAGMA count_changes=OFF')
        self.dataset_db.query('PRAGMA temp_store=OFF')
        self.dataset_db.query('PRAGMA default_temp_store=OFF')

        table_log = {}                   # table name -> dataset table
        current_log = defaultdict(list)  # table name -> rows not yet written
        current_trade = []               # trade rows not yet written
        self.table_aggregates = {}

        # Stays None when trade logging is disabled, so later code can test
        # for the table instead of hitting an unbound local.
        trade_table = None
        if self.trade_log:
            trade_table = self.dataset_db.create_table('trade___trade',
                                                       primary_id='index')

        while True:
            try:
                msg = self.in_sok.get()
            except KeyboardInterrupt:
                print("ADD simulation.finalize() after the simulation command "
                      "to write the simulation data and AVOID BLOCKING")
                break
            except EOFError:
                break

            if msg == "close":
                break
            elif msg[0] == 'snapshot_agg':
                _, round_, group, data_to_write = msg
                if self.round != round_:
                    # A new round begins: write out the finished round first.
                    self.make_aggregation_and_write()
                    self.round = round_
                for key, value in data_to_write.items():
                    self.aggregation[group][key].update(value)
            elif msg[0] == 'trade_log':
                if trade_table is None:
                    # Without a trade table the rows could never be written;
                    # drop them instead of buffering them indefinitely.
                    continue
                for (good, seller, buyer, price), quantity in msg[1].items():
                    current_trade.append({'round': msg[2],
                                          'good': good,
                                          'seller': seller,
                                          'buyer': buyer,
                                          'price': price,
                                          'quantity': quantity})
                # '>=' because one message can add several rows and step
                # over the exact threshold.
                if len(current_trade) >= self._BATCH_SIZE:
                    trade_table.insert_many(current_trade)
                    current_trade = []
            elif msg[0] == 'log':
                _, group, name, round_, data_to_write, subround_or_serial = msg
                table_name = 'panel___%s___%s' % (group, subround_or_serial)
                data_to_write['round'] = str(round_)
                data_to_write['name'] = str(name)
                current_log[table_name].append(data_to_write)
                if len(current_log[table_name]) >= self._BATCH_SIZE:
                    if table_name not in table_log:
                        table_log[table_name] = self.dataset_db.create_table(
                            table_name, primary_id='index')
                    table_log[table_name].insert_many(current_log[table_name])
                    current_log[table_name] = []
            else:
                # Only the attribute lookup is guarded, so an AttributeError
                # raised inside the plugin method is not misreported as an
                # unknown command.
                try:
                    handler = getattr(self.plugin, msg[0])
                except AttributeError:
                    # (msg,) keeps a tuple message from being unpacked as
                    # several format arguments.
                    raise AttributeError(
                        "abce_db error '%s' command unknown" % (msg,))
                handler(*msg[1], **msg[2])

        # Write whatever is still buffered.
        for name, data in current_log.items():
            if name not in table_log:
                table_log[name] = self.dataset_db.create_table(
                    name, primary_id='index')
            table_log[name].insert_many(data)
        self.make_aggregation_and_write()
        if trade_table is not None:
            trade_table.insert_many(current_trade)
        self.dataset_db.commit()

        # The plugin is optional and need not provide close(); errors raised
        # by an existing close() are allowed to propagate.
        close_plugin = getattr(self.plugin, 'close', None)
        if close_plugin is not None:
            close_plugin()

        if self.directory is not None:
            to_csv(self.directory, self.dataset_db)

    def make_aggregation_and_write(self):
        """Insert one aggregate row per group for ``self.round``.

        For every group in ``self.aggregation`` a row is built that holds
        ``'round'`` plus, for each aggregated key, ``<key>_ttl``,
        ``<key>_mean`` and ``<key>_std`` (taken from the accumulator's
        ``sum()``, ``mean()`` and ``std()``). The row is inserted into the
        table ``'aggregate___<group>'``, which is created on first use, and
        the group's accumulators are cleared afterwards.
        """
        for group, table in self.aggregation.items():
            result = {'round': self.round}
            for key, data in table.items():
                result[key + '_ttl'] = data.sum()
                result[key + '_mean'] = data.mean()
                result[key + '_std'] = data.std()
            if group not in self.table_aggregates:
                self.table_aggregates[group] = self.dataset_db.create_table(
                    'aggregate___%s' % group, primary_id='index')
            self.table_aggregates[group].insert(result)
            # Only the inner mapping is emptied; the group key itself stays,
            # so the outer mapping is not resized while being iterated.
            table.clear()