"""Algo class definition."""
import os
import sys
import logging
import argparse
import pandas as pd
from datetime import datetime
from aqtlib import utils, Broker, Porter
from aqtlib.objects import DataStore
from abc import abstractmethod
from .instrument import Instrument
__all__ = ["Algo"]  # public API of this module
class Algo(Broker):
    """Base class for AQTLib algorithms (strategies).

    Subclasses implement :meth:`initialize` and :meth:`on_bar`, then call
    :meth:`run`. Any non-default command-line arguments (see
    :meth:`load_cli_args`) override the corresponding settings.

    Args:
        instruments : list
            List of IB contract tuples.
        resolution : str
            Desired bar resolution (using pandas resolution: 1T, 1H, 1D, etc).
            Default is 1D (see ``defaults``)
        bars_window : int
            Length of bars lookback window to keep. Defaults to 120
        timezone : str
            Convert IB timestamps to this timezone (eg. US/Central).
            Defaults to UTC
        backtest: bool
            Whether to operate in Backtest mode (default: False)
        start: str
            Backtest start date (YYYY-MM-DD [HH:MM:SS[.MS]]). Default is None
            (required in Backtest mode)
        end: str
            Backtest end date (YYYY-MM-DD [HH:MM:SS[.MS]]). Default is None,
            which in Backtest mode is replaced by the current local time
        data : str
            Path to the directory with AQTLib-compatible CSV files (Backtest)
        output: str
            Path to save the recorded data (default: None; required in
            Backtest mode)
    """

    defaults = dict(
        instruments=[],
        resolution="1D",
        bars_window=120,
        timezone='UTC',
        backtest=False,
        start=None,
        end=None,
        data=None,
        output=None
    )

    # columns converted to pandas ``category`` dtype ("optimize pandas")
    _CATEGORY_COLUMNS = ('symbol', 'symbol_group', 'asset_class')

    def __init__(self, instruments, *args, **kwargs):
        """Set up the strategy, apply CLI overrides and call :meth:`initialize`."""
        super().__init__(instruments, *args, **kwargs)

        # strategy name
        self.name = self.__class__.__name__

        # initialize strategy logger
        self._logger = logging.getLogger(self.name)

        # override args with (non-default) command-line args
        self.update(**self.load_cli_args())

        self.backtest_csv = self.data

        # sanity checks for backtesting mode (may exit the process)
        if self.backtest:
            self._check_backtest_args()

        # initialize output file; record() is a no-op when output is unset
        self.record_ts = None
        self.datastore = DataStore(self.output) if self.output else None

        # rolling window of bars (all symbols) and last-seen bar hash per symbol
        self.bars = pd.DataFrame()
        self.bar_hashes = {}

        # signal collector: one (initially empty) DataFrame per symbol
        self.signals = {sym: pd.DataFrame() for sym in self.symbols}

        # user-defined setup hook
        self.initialize()

    # ---------------------------------------
    def _check_backtest_args(self):
        """Validate and normalize the arguments required for Backtest mode.

        Exits the process with status 1 when ``output`` or ``start`` is
        missing, or when the CSV directory does not exist. Fills ``end``
        with the current local time when it was not given, expands ``~``
        in the CSV directory path and strips one trailing slash from it.
        """
        if self.output is None:
            self._logger.error(
                "Must provide an output file for Backtest mode")
            sys.exit(1)

        if self.start is None:
            self._logger.error(
                "Must provide start date for Backtest mode")
            sys.exit(1)

        if self.end is None:
            self.end = datetime.now().strftime('%Y-%m-%d %H:%M:%S')

        if self.backtest_csv is not None:
            self.backtest_csv = os.path.expanduser(self.backtest_csv)
            # the CSV source is a directory holding one file per symbol
            if not os.path.isdir(self.backtest_csv):
                self._logger.error(
                    "CSV directory cannot be found ({dir})".format(dir=self.backtest_csv))
                sys.exit(1)
            if self.backtest_csv.endswith("/"):
                self.backtest_csv = self.backtest_csv[:-1]

    # ---------------------------------------
    def load_cli_args(self):
        """Parse command-line arguments and return only the non-default ones.

        Unknown arguments are ignored (``parse_known_args``).

        :Returns: dict
            a dict of any non-default args passed on the command-line.
        """
        parser = argparse.ArgumentParser(
            description='AQTLib Algorithm',
            formatter_class=argparse.ArgumentDefaultsHelpFormatter)

        parser.add_argument('--backtest', default=self.defaults['backtest'],
                            help='Work in Backtest mode (flag)',
                            action='store_true')
        parser.add_argument('--start', default=self.defaults['start'],
                            help='Backtest start date')
        parser.add_argument('--end', default=self.defaults['end'],
                            help='Backtest end date')
        parser.add_argument('--data', default=self.defaults['data'],
                            help='Path to backtester CSV files')
        parser.add_argument('--output', default=self.defaults['output'],
                            help='Path to save the recorded data')

        # only return non-default cmd line args
        # (meaning only those actually given)
        cmd_args, _ = parser.parse_known_args()
        return {arg: val for arg, val in vars(cmd_args).items()
                if val != parser.get_default(arg)}

    # ---------------------------------------
    @classmethod
    def _categorize(cls, df):
        """Convert the ``_CATEGORY_COLUMNS`` of ``df`` to ``category`` dtype in place."""
        for col in cls._CATEGORY_COLUMNS:
            df[col] = df[col].astype('category')

    # ---------------------------------------
    def run(self):
        """Start the algo.

        In Backtest mode, loads bar history (from CSV files when ``data`` is
        set, otherwise from the Porter database) and drips it bar by bar
        into :meth:`_bar_handler`, which passes it on to :meth:`on_bar`.
        CSV history is filtered to the ``[start, end]`` range here; Porter
        history is requested with ``start``/``end``.
        """
        history = pd.DataFrame()

        if self.backtest:
            self._logger.info('Algo start backtesting...')

            if self.backtest_csv:
                # history from csv dir
                dfs = self._fetch_csv()

                # prepare history data
                history = Porter.prepare_bars_history(
                    data=pd.concat(dfs, sort=True),
                    resolution=self.resolution,
                    tz=self.timezone
                )

                # .copy(): the filtered frame gets its columns reassigned below
                history = history[(history.index >= self.start) &
                                  (history.index <= self.end)].copy()

            else:
                # history from porter
                import nest_asyncio
                nest_asyncio.apply()

                # connect to database
                self.porter.connect_sql()

                history = self.porter.get_history(
                    symbols=self.symbols,
                    start=self.start,
                    end=self.end if self.end else datetime.now(),
                    resolution=self.resolution,
                    tz=self.timezone
                )

                # NOTE(review): the instrument passed here is hard-coded to
                # ('AAPL', 'STK') regardless of self.symbols; presumably it
                # should come from the requested instruments -- confirm
                # against utils.prepare_data before relying on multi-symbol
                # Porter history.
                history = utils.prepare_data(('AAPL', 'STK'), history, index=history.datetime)

            # optimize pandas
            if not history.empty:
                self._categorize(history)

            # drip history
            Porter.drip(history, self._bar_handler)

    # ---------------------------------------
    def _fetch_csv(self):
        """Load bars history from AQTLib-compatible CSV files.

        Reads ``<backtest_csv>/<symbol>.BAR.csv`` for every symbol in
        ``self.symbols``. Exits the process with status 1 if a file is
        missing, cannot be read, fails ``Porter.validate_csv``, or its last
        row does not belong to the expected symbol.

        :Returns: list
            one DataFrame per symbol, in ``self.symbols`` order.
        """
        dfs = []
        for symbol in self.symbols:
            file = os.path.join(
                self.backtest_csv,
                "{symbol}.{kind}.csv".format(symbol=symbol, kind="BAR"))

            if not os.path.exists(file):
                self._logger.error(
                    "Can't load data for {symbol} ({file} doesn't exist)".format(
                        symbol=symbol, file=file))
                sys.exit(1)

            # sys.exit raises SystemExit (a BaseException), so the
            # ``except Exception`` below does not intercept those exits.
            try:
                df = pd.read_csv(file)
                if not Porter.validate_csv(df, "BAR"):
                    self._logger.error("{file} isn't a AQTLib-compatible format".format(file=file))
                    sys.exit(1)

                if df.empty or df['symbol'].values[-1] != symbol:
                    self._logger.error(
                        "{file} doesn't contain data for {symbol}".format(file=file, symbol=symbol))
                    sys.exit(1)

                dfs.append(df)

            except Exception as e:
                # Logger.error takes %-style positional args, not arbitrary kwargs
                self._logger.error(
                    "Error reading data for %s (%s)", symbol, e)
                sys.exit(1)

        return dfs

    # ---------------------------------------
    def _bar_handler(self, bar):
        """Handle one bar delivered by ``Porter.drip``.

        Appends ``bar`` to the rolling ``self.bars`` window and, when the bar
        differs from the last one seen for its symbol, calls :meth:`on_bar`
        with the symbol's instrument and records the bar.
        """
        symbols = bar['symbol'].values
        if len(symbols) == 0:
            return
        symbol = symbols[0]

        self.bars = self._update_window(self.bars, bar,
                                        window=self.bars_window)

        # optimize pandas
        if len(self.bars) == 1:
            self._categorize(self.bars)

        # new bar? Series.to_string() includes the index (the bar time, see
        # record_ts below), so the hash changes with each new bar timestamp.
        hash_string = bar[:1]['symbol'].to_string().translate(
            str.maketrans({key: None for key in "\n -:+"}))
        this_bar_hash = abs(hash(hash_string)) % (10 ** 8)

        # an unseen symbol yields None, which never equals an int hash
        newbar = self.bar_hashes.get(symbol) != this_bar_hash
        self.bar_hashes[symbol] = this_bar_hash

        if not newbar:
            return

        if self.bars[(self.bars['symbol'] == symbol) |
                     (self.bars['symbol_group'] == symbol)].empty:
            return

        instrument = self.get_instrument(symbol)
        if instrument:
            self.record_ts = bar.index[0]
            self._logger.debug('BAR TIME: %s', self.record_ts)
            self.on_bar(instrument)
            self.record(bar)

    # ---------------------------------------
    def _update_window(self, df, data, window=None, resolution=None):
        """Append ``data`` to ``df`` and keep the last ``window`` bars per symbol.

        ``window=None`` keeps everything. ``resolution`` is accepted for
        interface compatibility and is not used.
        """
        # pd.concat(..., sort=True) is what DataFrame.append (removed in
        # pandas 2.0) did internally.
        df = pd.concat([df, data], sort=True) if df is not None else data

        if window is None:
            return df

        return self._get_window_per_symbol(df, window)

    # ---------------------------------------
    @staticmethod
    def _get_window_per_symbol(df, window):
        """Truncate ``df`` to its last ``window`` rows per symbol, sorted by index."""
        dfs = [df[df['symbol'] == symbol].iloc[-window:]
               for symbol in df['symbol'].unique()]
        if not dfs:
            # nothing to concatenate (pd.concat([]) raises ValueError)
            return df
        return pd.concat(dfs, sort=True).sort_index()

    # ---------------------------------------
    def get_instrument(self, symbol):
        """Return an ``Instrument`` for ``symbol`` attached to this strategy.

        ``Instrument`` is a string subclass giving shorthand access to
        symbol-related methods and information.

        Call from within your strategy:
        ``instrument = self.get_instrument("SYMBOL")``
        """
        instrument = Instrument(symbol)
        instrument.attach_strategy(self)
        return instrument

    # ---------------------------------------
    @abstractmethod
    def on_bar(self, instrument):
        """Invoked on every new bar captured for the selected instrument.

        This is where you'll write your strategy logic for bar events.
        """
        pass

    # ---------------------------------------
    @abstractmethod
    def initialize(self):
        """Invoked once at the end of ``__init__``.

        Use it when the strategy needs to initialize parameters upon starting.
        """
        pass

    # ---------------------------------------
    def order(self, signal, symbol, quantity=0, **kwargs):
        """Prepare an order for the selected instrument and record the new position.

        :Parameters:
            signal : string
                Order signal (BUY/SELL, EXIT/FLATTEN); case-insensitive
            symbol : string
                instrument symbol
            quantity : int
                Order quantity (ignored for EXIT/FLATTEN, which close the
                whole current position)

        :Optional:
            limit_price : float
                In case of a LIMIT order, this is the LIMIT PRICE
            expiry : int
                Cancel this order if not filled after *n* seconds
                (default 60 seconds)
            order_type : string
                Type of order: Market (default),
                LIMIT (default when limit_price is passed),
                MODIFY (required passing or orderId)
            orderId : int
                If modifying an order, the order id of the modified order
            target : float
                Target (exit) price
            initial_stop : float
                Price to set hard stop
            stop_limit: bool
                Flag to indicate if the stop should be STOP or STOP LIMIT.
                Default is ``False`` (STOP)
            trail_stop_at : float
                Price at which to start trailing the stop
            trail_stop_type : string
                Type of trailing stop offset (amount, percent).
                Default is ``percent``
            trail_stop_by : float
                Offset of trailing stop distance from current price
            fillorkill: bool
                Fill entire quantity or none at all
            iceberg: bool
                Is this an iceberg (hidden) order
            tif: str
                Time in force (DAY, GTC, IOC, GTD). default is ``DAY``

        Returns early (doing nothing) when closing a flat position or when
        ``quantity`` is 0 for a BUY/SELL signal.
        """
        self._logger.debug('ORDER: %s %4d %s %s', signal,
                           quantity, symbol, kwargs)

        position = self.get_positions(symbol)
        direction = signal.upper()

        if direction in ("EXIT", "FLATTEN"):
            if position['position'] == 0:
                return

            kwargs['symbol'] = symbol
            kwargs['quantity'] = abs(position['position'])
            # close a short by buying, a long by selling
            kwargs['direction'] = "BUY" if position['position'] < 0 else "SELL"

            try:
                self.record({symbol + '_POSITION': 0})
            except Exception as e:
                self._logger.warning(
                    "Could not record position for %s: %s", symbol, e)

        else:
            if quantity == 0:
                return

            kwargs['symbol'] = symbol
            kwargs['quantity'] = abs(quantity)
            kwargs['direction'] = direction

            # record the expected position after this order
            try:
                signed_qty = abs(quantity) if direction == "BUY" else -abs(quantity)
                self.record({symbol + '_POSITION': signed_qty + position['position']})
            except Exception as e:
                self._logger.warning(
                    "Could not record position for %s: %s", symbol, e)

        # NOTE(review): the assembled kwargs are not forwarded to any
        # order-submission call in this method -- confirm this is intended.

    # ---------------------------------------
    def record(self, *args, **kwargs):
        """Record data for later analysis.

        When ``output`` is set, the values are passed, together with the
        current bar timestamp (``record_ts``), to the ``DataStore`` created
        for ``--output [file]``. Otherwise this is a no-op.

        Call from within your strategy:
        ``self.record(key=value)`` (a dict or bar DataFrame may also be
        passed positionally).

        :Parameters:
            ** kwargs : mixed
                The names and values to record
        """
        if not self.output:
            return
        try:
            self.datastore.record(self.record_ts, *args, **kwargs)
        except Exception as e:
            # recording is best-effort: report it but never interrupt the strategy
            self._logger.warning("Failed to record data: %s", e)