diff --git a/cogs/commands/market.py b/cogs/commands/market.py index a8dd9e4d..1580971a 100644 --- a/cogs/commands/market.py +++ b/cogs/commands/market.py @@ -1,5 +1,6 @@ import heapq import time +from collections import defaultdict from discord.ext import commands from discord.ext.commands import Bot, Context, check, clean_content @@ -14,11 +15,12 @@ class Order: - def __init__(self, price, order_type, user_id): + def __init__(self, price, order_type, user_id, qty, order_time): self.user_id = user_id self.price = price self.order_type = order_type - self.order_time = time.time() + self.qty = qty + self.order_time = order_time def __lt__(self, other): if self.order_type == 'ask': @@ -37,7 +39,7 @@ def __gt__(self, other): return self.price < other.price or (self.price == other.price and self.order_time > other.order_time) def __str__(self): - return f'{self.order_type} <@{self.price}> <@{self.user_id}>' + return f'{self.order_type} {self.qty}@<@{self.price}> <@{self.user_id}>' class Market: def __init__(self, stock_name): @@ -50,23 +52,22 @@ def __init__(self, stock_name): self.last_trade = None self.open = True - def bid(self, price, user_id): - self.bids.append(Order(price, 'bid', user_id)) - heapq.heapify(self.bids) + def bid(self, price, user_id, qty, order_time=None): + order_time = time.time() if order_time is None else order_time + heapq.heappush(self.bids, Order(price, 'bid', user_id, qty, order_time)) return self.match() - def ask(self, price, user_id): - self.asks.append(Order(price, 'ask', user_id)) - heapq.heapify(self.asks) + def ask(self, price, user_id, qty, order_time=None): + order_time = time.time() if order_time is None else order_time + heapq.heappush(self.asks, Order(price, 'ask', user_id, qty, order_time)) return self.match() def match(self): - if len(self.bids) == 0 or len(self.asks) == 0: - return None - - if self.bids[0].price >= self.asks[0].price: + matched = [] + while len(self.bids) > 0 and len(self.asks) > 0 and self.bids[0].price >= self.asks[0].price: bid = heapq.heappop(self.bids) ask = heapq.heappop(self.asks) + qty = min(bid.qty, ask.qty) if bid.user_id not in self.trade_history: self.trade_history[bid.user_id] = [] @@ -76,33 +77,33 @@ def match(self): earliest_trade = min(bid, ask, key=lambda x: x.order_time) - bid.price = earliest_trade.price - ask.price = earliest_trade.price + bought = Order(earliest_trade.price, 'bid', bid.user_id, qty, bid.order_time) + sold = Order(earliest_trade.price, 'ask', ask.user_id, qty, ask.order_time) - self.trade_history[bid.user_id].append(bid) - self.trade_history[ask.user_id].append(ask) - - self.last_trade = f"<@{bid.user_id}> bought from <@{ask.user_id}> at {bid.price}" + self.trade_history[bid.user_id].append(bought) + self.trade_history[ask.user_id].append(sold) - return self.last_trade - return None + self.last_trade = f"<@{bid.user_id}> bought {qty} from <@{ask.user_id}> at {bought.price}" + + if ask.qty > qty: + heapq.heappush(self.asks, Order(ask.price, 'ask', ask.user_id, ask.qty - qty, ask.order_time)) + elif bid.qty > qty: + heapq.heappush(self.bids, Order(bid.price, 'bid', bid.user_id, bid.qty - qty, bid.order_time)) + + matched.append(self.last_trade) + + return "\n".join(matched) if len(matched) > 0 else None def close_market(self, valuation): user_to_profit = {} for user in self.trade_history: - user_valuation = 0 - for trade in self.trade_history[user]: - if trade.order_type == 'bid': - user_valuation -= trade.price - user_valuation += valuation - else: - user_valuation += trade.price - user_valuation -= valuation - - user_to_profit[user] = user_valuation - + closing = valuation * sum(trade.qty if trade.order_type == 'bid' else -trade.qty for trade in self.trade_history[user]) + # Note: accumulating _value_ not position, so signs are reversed + pnl = sum(trade.price * (trade.qty if trade.order_type == 'ask' else -trade.qty) for trade in self.trade_history[user]) + user_to_profit[user] = closing + pnl + self.open = False return user_to_profit @@ -135,13 +136,17 @@ def __str__(self): ret_str = "Market is: " ret_str += "OPEN\n\n" if self.open else "CLOSED\n\n" - # Count bids and asks for each price level - bid_counts = {} - ask_counts = {} + # Count bids and asks and sum quantity for each price level + bid_counts = defaultdict(lambda: [0,0]) + ask_counts = defaultdict(lambda: [0,0]) for bid in self.bids: - bid_counts[bid.price] = bid_counts.get(bid.price, 0) + 1 + level = bid_counts[bid.price] + level[0] += 1 + level[1] += bid.qty for ask in self.asks: - ask_counts[ask.price] = ask_counts.get(ask.price, 0) + 1 + level = ask_counts[ask.price] + level[0] += 1 + level[1] += ask.qty # Get price levels; highest first all_prices = sorted(set(bid_counts.keys()).union(set(ask_counts.keys())), reverse=True) @@ -152,13 +157,13 @@ def __str__(self): order_book_lines.append("No outstanding orders\n") else: order_book_lines.append("```") - order_book_lines.append(f"{'Bid Volume':<15} | {'Price':<10} | {'Ask Volume'}") + order_book_lines.append(f"{'Bid Orders':<15} | {'Bid Volume':<15} | {'Price':<10} | {'Ask Volume':<15} | {'Ask Orders'}") for price in all_prices: - bid_vol = bid_counts.get(price, " " * 15) - ask_vol = ask_counts.get(price, " " * 10) + bid_vol = bid_counts.get(price, [" " * 15] * 2) + ask_vol = ask_counts.get(price, [" " * 10] * 2) formatted_price = f"{price:.2f}" - order_book_lines.append(f"{str(bid_vol):<15} | {str(formatted_price):<10} | {str(ask_vol)}") + order_book_lines.append(f"{str(bid_vol[0]):<15} | {str(bid_vol[1]):<15} | {str(formatted_price):<10} | {str(ask_vol[1]):<15} | {str(ask_vol[0])}") order_book_lines.append("```") @@ -200,9 +205,9 @@ async def view_market(self, ctx: Context, *, market: clean_content): await ctx.reply(market_str, ephemeral=True) @commands.hybrid_command(help=LONG_HELP_TEXT, brief=SHORT_HELP_TEXT) - async def bid_market(self, ctx: Context, price: float, *, market: clean_content): + async def bid_market(self, ctx: Context, price: float, qty: int, *, market: clean_content): """You would place a bid by using this command - '!bid_market 100 "AAPL"' + '!bid_market 123.4 15 "AAPL"' """ if market not in self.live_markets: await ctx.reply("Market does not exist", ephemeral=True) @@ -214,7 +219,7 @@ async def bid_market(self, ctx: Context, price: float, *, market: clean_content) await ctx.reply("Market is closed", ephemeral=True) return - did_trade = market_obj.bid(price, ctx.author.id) + did_trade = market_obj.bid(price, ctx.author.id, qty) await ctx.reply("Bid placed", ephemeral=True) @@ -222,7 +227,11 @@ async def bid_market(self, ctx: Context, price: float, *, market: clean_content) await ctx.reply(did_trade, ephemeral=False) @commands.hybrid_command(help=LONG_HELP_TEXT, brief=SHORT_HELP_TEXT) - async def ask_market(self, ctx: Context, price: float, *, market: clean_content): + async def ask_market(self, ctx: Context, price: float, qty: int, *, market: clean_content): + """You would place an ask by using this command + '!ask_market 123.4 15 "AAPL"' + """ + if market not in self.live_markets: await ctx.reply("Market does not exist", ephemeral=True) return @@ -234,7 +243,7 @@ async def ask_market(self, ctx: Context, price: float, *, market: clean_content) return - did_trade = market_obj.ask(price, ctx.author.id) + did_trade = market_obj.ask(price, ctx.author.id, qty) await ctx.reply("Ask placed", ephemeral=True) @@ -250,14 +259,14 @@ async def positions_market(self, ctx: Context, *, market: clean_content): market_obj = self.live_markets[market] user_trades = market_obj.trade_history.get(ctx.author.id, []) - user_asks = [trade.price for trade in user_trades if trade.order_type == 'ask'] - user_bids = [trade.price for trade in user_trades if trade.order_type == 'bid'] + user_asks = "\n".join(f"{trade.qty}@{trade.price}" for trade in user_trades if trade.order_type == 'ask') + user_bids = "\n".join(f"{trade.qty}@{trade.price}" for trade in user_trades if trade.order_type == 'bid') + net = sum(trade.qty if trade.order_type == 'bid' else -trade.qty for trade in user_trades) positions = f"Positions for <@{ctx.author.id}> in {market_obj.stock_name}\n" - positions += "Bids\n" - positions += "\n".join([str(bid) for bid in user_bids]) - positions += "\n\nAsks\n" - positions += "\n".join([str(ask) for ask in user_asks]) + positions += f"Net position: {net}\n" + positions += f"Bids\n{user_bids}" + positions += f"Asks\n{user_asks}" await ctx.reply(str(positions), ephemeral=True) diff --git a/tests/test_market.py b/tests/test_market.py new file mode 100644 index 00000000..c532a86b --- /dev/null +++ b/tests/test_market.py @@ -0,0 +1,86 @@ +from cogs.commands.market import Market + + +def test_can_place_orders(): + m = Market("TEST") + assert m.ask(102, 1, 1, 1) is None + assert m.bid(101, 2, 3, 2) is None + assert m.ask(102, 3, 4, 3) is None + assert len(m.asks) == 2 + assert len(m.bids) == 1 + + assert str(m) == """Market is: OPEN + +📊 **TEST Order Book** 📊 +``` +Bid Orders | Bid Volume | Price | Ask Volume | Ask Orders + | | 102.00 | 5 | 2 +1 | 3 | 101.00 | | +``` +Last Trade: None""" + +def test_single_match(): + m = Market("test") + assert m.ask(101, 1, 1) is None + assert (matched := m.bid(101, 2, 1)) is not None + assert matched == "<@2> bought 1 from <@1> at 101" + assert len(m.asks) == 0 + assert len(m.bids) == 0 + +def test_partial_match(): + m = Market("test") + assert m.ask(102, 1, 100, 1) is None + assert (o := m.bid(102, 2, 50, 2)) is not None + assert o == "<@2> bought 50 from <@1> at 102" + assert len(m.bids) == 0 + assert len(m.asks) == 1 + assert m.asks[0].qty == 50 + assert m.asks[0].order_time == 1 + +def test_multi_match(): + m = Market("test") + assert m.ask(102, 1, 1, 1) is None + assert m.ask(102, 2, 1, 2) is None + assert m.ask(102, 3, 1, 4) is None + assert m.bid(102, 4, 2, 5) == """<@4> bought 1 from <@1> at 102 +<@4> bought 1 from <@2> at 102""" + assert len(m.bids) == 0 + assert len(m.asks) == 1 + assert m.asks[0].user_id == 3 + assert len(m.trade_history[1]) == 1 + assert len(m.trade_history[2]) == 1 + assert 3 not in m.trade_history + assert len(m.trade_history[4]) == 2 + +def test_turning(): + m = Market("test") + assert m.ask(102, 1, 1, 1) is None + assert m.bid(102, 2, 100, 2) == """<@2> bought 1 from <@1> at 102""" + assert len(m.asks) == 0 + assert len(m.bids) == 1 + assert m.bids[0].qty == 99 + assert m.bids[0].order_time == 2 + +def test_multi_level_clear(): + m = Market("test") + assert m.ask(100, 1, 1, 1) is None + assert m.ask(101, 1, 1, 2) is None + assert m.ask(102, 1, 1, 3) is None + assert m.ask(103, 1, 100, 4) is None + assert m.bid(103, 2, 10, 5) == """<@2> bought 1 from <@1> at 100 +<@2> bought 1 from <@1> at 101 +<@2> bought 1 from <@1> at 102 +<@2> bought 7 from <@1> at 103""" + assert len(m.bids) == 0 + assert len(m.asks) == 1 + assert m.asks[0].qty == 93 + assert m.asks[0].price == 103 + assert len(m.trade_history[1]) == 4 + assert len(m.trade_history[2]) == 4 + +def test_times(): + m = Market("test") + assert m.ask(100, 1 ,1) is None + assert m.ask(101, 2, 1) is None + assert len(set([o.order_time for o in m.asks])) == len(m.asks) + assert m.bid(100, 2, 1) == "<@2> bought 1 from <@1> at 100"