Source code for aqtlib.porter

"""Porter class definition."""

import os
import sys
import asyncio
import argparse
import logging

import pandas as pd
import numpy as np

from sqlalchemy.engine.url import URL
from sqlalchemy import select, and_
from datetime import datetime
from typing import List, Awaitable
from aqtlib import Object, utils
from apgsa import PG
from ib_insync import IB, Forex

from .schema import metadata, symbols, bars, ticks


__all__ = ['Porter']


class Porter(Object):
    """Porter class initializer.

    Records market data from Interactive Brokers into PostgreSQL and offers
    helpers to prepare, store and replay historical bars.

    Args:
        symbols : str
            IB contracts CSV database (default: ./symbols.csv)
        ib_port : int
            TWS/GW Port to use (default: 4002)
        ib_client : int
            TWS/GW Client ID (default: 100)
        ib_server : str
            IB TWS/GW Server hostname (default: localhost)
        db_host : str
            PostgreSQL server hostname (default: localhost)
        db_port : int
            PostgreSQL server port (default: 5432)
        db_name : str
            PostgreSQL server database (default: aqtlib_db)
        db_user : str
            PostgreSQL server username (default: aqtlib_user)
        db_pass : str
            PostgreSQL server password (default: aqtlib_pass)
        db_skip : bool
            Skip PostgreSQL logging (default: False)
    """

    # timeout handed to ``utils.run`` by ``_run``
    RequestTimeout = 0

    defaults = dict(
        symbols='symbols.csv',
        ib_port=4002,  # 7496/7497 = TWS, 4001/4002 = IBGateway
        ib_client=100,
        ib_server='localhost',
        db_host='localhost',
        db_port=5432,
        db_name='aqtlib_db',
        db_user='aqtlib_user',
        db_pass='aqtlib_pass',
        db_skip=False
    )
    # __slots__ = defaults.keys()

    def __init__(self, *args, **kwargs):
        """Build the porter and wire up the database / IB clients.

        Positional and keyword arguments are forwarded to ``Object.__init__``;
        options actually given on the command line then override them.
        """
        Object.__init__(self, *args, **kwargs)

        # initialize class logger
        self._logger = logging.getLogger(__name__)

        # override with (non-default) command-line args
        self.update(**self.load_cli_args())

        # database manager
        self.pg = PG()

        # sync/async framework for Interactive Brokers
        self.ib = IB()
        self.ib.pendingTickersEvent += self.onPendingTickers

        # do not act on first tick (incorrect)
        self.first_tick = True

        self._loop = asyncio.get_event_loop()

    def onPendingTickers(self, tickers):
        """Handle and record tickers from Interactive Brokers.

        The very first batch is discarded; afterwards every ticker is reduced
        to its quote fields and the batch is inserted into the ``ticks`` table
        as a fire-and-forget task.

        Args:
            tickers: iterable of objects exposing a ``dict()`` method.
        """
        # do not act on first incorrect tick
        if self.first_tick:
            self.first_tick = False
            return

        fields = {'bid', 'bidSize', 'ask', 'askSize', 'time'}

        # keep only the wanted attributes of every ticker
        data = [
            {k: v for k, v in ticker.dict().items() if k in fields}
            for ticker in tickers
        ]
        if not data:
            # nothing to record; avoid issuing an INSERT without rows
            return

        asyncio.ensure_future(self.pg.execute(ticks.insert().values(data)))

    # -------------------------------------------
    def load_cli_args(self):
        """Parse the command line and return only the options that were given.

        Returns:
            dict: option name -> value for every option whose parsed value
            differs from the parser default (unknown arguments are ignored).
        """
        parser = argparse.ArgumentParser(
            description='AQTLib Porter',
            formatter_class=argparse.ArgumentDefaultsHelpFormatter)

        # numeric options are parsed as int so that a CLI override keeps the
        # same type as the built-in default
        parser.add_argument('--ib_port', default=self.ib_port, type=int,
                            help='TWS/GW Port to use', required=False)
        parser.add_argument('--ib_client', default=self.ib_client, type=int,
                            help='TWS/GW Client ID', required=False)
        parser.add_argument('--ib_server', default=self.ib_server,
                            help='IB TWS/GW Server hostname', required=False)
        parser.add_argument('--orderbook', action='store_true',
                            help='Get Order Book (Market Depth) data',
                            required=False)
        parser.add_argument('--db_host', default=self.db_host,
                            help='PostgreSQL server hostname', required=False)
        parser.add_argument('--db_port', default=self.db_port, type=int,
                            help='PostgreSQL server port', required=False)
        parser.add_argument('--db_name', default=self.db_name,
                            help='PostgreSQL server database', required=False)
        parser.add_argument('--db_user', default=self.db_user,
                            help='PostgreSQL server username', required=False)
        parser.add_argument('--db_pass', default=self.db_pass,
                            help='PostgreSQL server password', required=False)
        parser.add_argument('--db_skip', default=self.db_skip,
                            required=False,
                            help='Skip PostgreSQL logging (flag)',
                            action='store_true')

        # only return non-default cmd line args
        # (meaning only those actually given)
        cmd_args, _ = parser.parse_known_args()
        return {
            arg: val for arg, val in vars(cmd_args).items()
            if val != parser.get_default(arg)
        }

    def _run(self, *awaitables: Awaitable):
        """Run the awaitables through ``utils.run`` with ``RequestTimeout``."""
        return utils.run(*awaitables, timeout=self.RequestTimeout)

    def run(self):
        """Starts the Porter.

        Runs the event loop captured at construction time forever; the
        coroutines and event handlers scheduled on it do the actual work.
        """
        self._loop.run_forever()

    def connect_sql(self):
        """Connect the database manager to PostgreSQL and log the endpoint."""
        # NOTE(review): db_port is logged but not handed to ``connect`` —
        # confirm whether the PG client needs it.
        self.pg.connect(
            self.db_host, self.db_name, self.db_user, self.db_pass)
        self._logger.info(
            "PostgreSQL %s:%s Connected.", self.db_host, self.db_port)

    async def get_symbol_id_async(self, symbol):
        """Return the first column of the ``symbols`` row matching ``symbol``.

        The row is looked up by (symbol, symbol_group, asset_class); when it
        does not exist yet it is inserted (with ``expiry`` left empty) and
        looked up again.

        Args:
            symbol : str
                Symbol string; the ``_<asset_class>`` part is removed before
                the lookup.
        """
        asset_class = utils.gen_asset_class(symbol)
        symbol_group = utils.gen_symbol_group(symbol)
        clean_symbol = symbol.replace("_" + asset_class, "")
        expiry = None

        async def querySymbolIdAsync(asset_class, symbol_group, clean_symbol):
            sql = select([symbols]).where(and_(
                symbols.c.symbol == clean_symbol,
                symbols.c.symbol_group == symbol_group,
                symbols.c.asset_class == asset_class)
            )
            return await self.pg.fetchrow(sql)

        # symbol already in db
        row = await querySymbolIdAsync(asset_class, symbol_group, clean_symbol)
        if row is not None:
            return row[0]

        # symbol/expiry not in db... insert new/update expiry
        # TODO: add expiry
        data = {
            'symbol': clean_symbol,
            'symbol_group': symbol_group,
            'asset_class': asset_class,
            'expiry': expiry
        }
        await self.pg.execute(symbols.insert([data]))

        row = await querySymbolIdAsync(asset_class, symbol_group, clean_symbol)
        return row[0]

    async def store_data_async(self, df, kind="BAR"):
        """Persist a prepared DataFrame, one symbol at a time.

        Args:
            df : pd.DataFrame
                Data with a ``symbol`` column; the index is stored as the
                ``datetime`` column.
            kind : str
                Only ``BAR`` rows are written; any other kind is skipped.

        Returns:
            bool: ``True`` once every symbol has been processed.

        Raises:
            ValueError: if ``utils.validate_columns`` rejects ``df``.
        """
        # validate columns
        if not utils.validate_columns(df, kind):
            raise ValueError('Invalid Column list')

        # loop through symbols and save in db
        for symbol in df['symbol'].unique():
            # explicit copy: the columns added below must not write through
            # to (or warn about) the caller's frame
            data = df[df['symbol'] == symbol].copy()
            symbol_id = await self.get_symbol_id_async(symbol)

            # prepare columns for insert
            data['datetime'] = data.index
            data['symbol_id'] = symbol_id

            # 'expiry' is not always present (prepare_data does not add it),
            # so missing labels are ignored instead of raising KeyError
            data = data.drop(
                ['symbol', 'symbol_group', 'asset_class', 'expiry'],
                axis=1, errors='ignore')

            # insert row by row to handle greeks
            records = data.to_dict(orient="records")

            if kind == "BAR":
                for row in records:
                    await self.pg.execute(bars.insert().values([row]))

        return True

    # -------------------------------------------
    async def get_data_async(self, sql) -> pd.DataFrame:
        """Fetch ``sql`` and return the rows as a DataFrame.

        Returns an empty DataFrame when the query yields nothing; otherwise
        the columns are named after the keys of the first row.
        """
        data = await self.pg.fetch(sql)
        if not data:
            return pd.DataFrame()

        return pd.DataFrame(data, columns=list(data[0].keys()))

    def get_history(self, symbols, start, end=None, resolution="1T",
                    tz="UTC", continuous=True):
        """Load stored bars whose datetime lies within ``[start, end]``.

        Args:
            symbols, resolution, tz, continuous:
                accepted but currently not applied to the query.
            start : datetime
                Lower bound (inclusive).
            end : datetime
                Upper bound (inclusive); defaults to ``datetime.now()``.

        Returns:
            pd.DataFrame: datetime/open/high/low/close/volume columns.
        """
        if end is None:
            end = datetime.now()

        sql_query = select([
            bars.c.datetime, bars.c.open, bars.c.high,
            bars.c.low, bars.c.close, bars.c.volume]).where(
                and_(bars.c.datetime >= start, bars.c.datetime <= end))

        return utils.run(self.get_data_async(sql_query))

    # ---------------------------------------------
    @staticmethod
    def validate_csv(df: pd.DataFrame, kind: str = "BAR") -> bool:
        """Check if ``df`` is an AQTLib-compatible CSV frame.

        Args:
            df : pd.DataFrame
                Frame to check.
            kind : str
                Currently unused.

        Returns:
            bool: ``True`` when every required bar column is present.

        Raises:
            ValueError: naming the first required column that is missing.
        """
        _BARS_COLS = ('asset_class', 'open', 'high', 'low', 'close', 'volume')
        for el in _BARS_COLS:
            if el not in df.columns:
                raise ValueError('Column {el} not found'.format(el=el))
        return True

    # -------------------------------------------
    @staticmethod
    def prepare_data(instrument, data, output_path=None, index=None,
                     colsmap=None, kind="BAR", resample="1T") -> pd.DataFrame:
        """Converts given DataFrame to a AQTLib-compatible format csv file.

        :Parameters:
            instrument : mixed
                IB contract tuple / string (same as that given to strategy)
            data : pd.DataFrame
                Pandas DataFrame with that instrument's market data
            output_path : str
                Path to where the resulting CSV should be saved (optional)
            index : pd.Series
                Pandas Series that will be used for df's index (optional;
                currently unused)
            colsmap : dict
                Dict for mapping df's columns to those used by AQTLib
                (currently unused)
            kind : str
                Is this ``BAR`` or ``TICK`` data
            resample : str
                Pandas resolution (defaults to 1min/1T; currently unused)

        :Returns:
            data : pd.DataFrame
                Pandas DataFrame in a AQTLib-compatible format and timezone
        """
        df = data.copy()

        # jquant's csv?
        jquant_cols = {'close', 'open', 'high', 'low', 'volume', 'money'}
        if set(df.columns) == jquant_cols:
            df.index = df.index.tz_localize(
                utils.get_timezone()).tz_convert('UTC')

        # FIXME: generate a valid ib tuple
        symbol_group = instrument[0]
        asset_class = instrument[1]
        symbol = symbol_group + '_' + asset_class

        df.loc[:, 'symbol'] = symbol
        df.loc[:, 'symbol_group'] = symbol_group
        df.loc[:, 'asset_class'] = asset_class

        # TODO: validate, remove and map columns
        df.index.rename("datetime", inplace=True)

        # save csv
        if output_path is not None:
            target = os.path.join(
                os.path.expanduser(output_path),
                "{symbol}.{kind}.csv".format(symbol=symbol, kind=kind))
            df.to_csv(target)

        return df

    # -------------------------------------------
    @staticmethod
    def prepare_bars_history(data, resolution="1T", tz=None):
        """Resample stored bars per symbol to ``resolution``.

        Args:
            data : pd.DataFrame
                Frame with ``datetime``, ``symbol``, ``symbol_group`` and
                ``asset_class`` columns plus OHLCV columns. It is not
                modified.
            resolution : str
                Pandas resampling rule.
            tz : str
                Optional timezone the UTC index is converted to.

        Returns:
            pd.DataFrame: resampled bars of all symbols, NaN rows dropped,
            ``volume`` cast to int.
        """
        # setup dataframe (on a new object, so the caller's frame is kept
        # intact)
        data = data.set_index('datetime')
        data.index = pd.to_datetime(data.index, utc=True)

        # meta data
        meta_data = data.groupby(["symbol"])[
            ['symbol', 'symbol_group', 'asset_class']].last()

        # convert timezone (once, not per symbol)
        if tz:
            data.index = data.index.tz_convert(tz)

        bars_ohlc_dict = {
            'open': 'first',
            'high': 'max',
            'low': 'min',
            'close': 'last',
            'volume': 'sum'
        }
        # aggregation per available OHLCV column, in the frame's column order
        bar_dict = {
            col: bars_ohlc_dict[col]
            for col in data.columns if col in bars_ohlc_dict
        }

        combined = []
        for symbol in meta_data.index.values:
            resampled = data[data['symbol'] == symbol].resample(
                resolution).apply(bar_dict)

            # drop NANs
            resampled.dropna(inplace=True)

            meta_row = meta_data[meta_data.index == symbol]
            resampled['symbol'] = symbol
            resampled['symbol_group'] = meta_row['symbol_group'].values[0]
            resampled['asset_class'] = meta_row['asset_class'].values[0]

            combined.append(resampled)

        data = pd.concat(combined, sort=True)
        data['volume'] = data['volume'].astype(int)

        return data

    # -------------------------------------------
    @staticmethod
    def drip(history, handler):
        """Replay history data, handing each record to ``handler``.

        Args:
            history : pd.DataFrame
                Records to replay, in order.
            handler : callable
                Called with a one-row DataFrame per record.

        On ``KeyboardInterrupt`` / ``SystemExit`` a notice is printed and the
        process exits with status 1.
        """
        try:
            for i in range(len(history)):
                # slice (not scalar index) so the handler gets a DataFrame
                handler(history.iloc[i:i + 1])

            print("\n\n>>> Backtesting Completed.")

        except (KeyboardInterrupt, SystemExit):
            print(
                "\n\n>>> Interrupted with Ctrl-c...\n\n")

            print(".\n.\n.\n")
            sys.exit(1)