diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..8470200 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +tmp/ +.env + +coverage.out \ No newline at end of file diff --git a/README.md b/README.md index 63df5f5..ad299c5 100644 --- a/README.md +++ b/README.md @@ -7,8 +7,8 @@ A fun implementation of an asset exchange using Go. I am mostly doing this to le - [x] Handle partial fills. - [x] Handle market orders. - [x] Handle multiple assets. -- [ ] Create an API layer to handle orders from customers (net/http). +- [x] Create an API layer to handle orders from customers (net/http). - [ ] Simulate large order flow and experiment with concurrency (Channels + Goroutines). - [ ] Expose orderbook using websockets. - [ ] Simple UI to showcase project. -- [x] Maintain 80%+ code coverage. +- [ ] Maintain 80%+ code coverage. diff --git a/cmd/seed/main.go b/cmd/seed/main.go new file mode 100644 index 0000000..d4b8272 --- /dev/null +++ b/cmd/seed/main.go @@ -0,0 +1,60 @@ +package main + +import ( + "context" + "database/sql" + "log" + "time" + + _ "github.com/lib/pq" + + "exchange/internal/db" +) + +type Pair struct { + Ticker string + Name string +} + +func main() { + connStr := "postgresql://postgres:postgres@localhost:5432/testdb?sslmode=disable" + database, err := sql.Open("postgres", connStr) + if err != nil { + log.Fatalf("could not connect to db: %v", err) + } + queries := db.New(database) + + log.Print("seeding database...") + + // Example seeding assets + assets := []Pair{ + {Ticker: "BTC", Name: "Bitcoin"}, + {Ticker: "ETH", Name: "Ethereum"}, + {Ticker: "LINK", Name: "Chainlink"}, + {Ticker: "MADT", Name: "Tokenized MAD"}, + } + + ctx := context.Background() + + for _, asset := range assets { + _, err := queries.CreateAsset(ctx, db.CreateAssetParams{ + AssetName: asset.Name, + Ticker: asset.Ticker, + }) + if err != nil { + log.Printf("error seeding asset %s: %v", asset.Ticker, err) + } else { + log.Printf("seeded asset: %s", asset.Ticker) + } + } + + queries.CreateUser(ctx, db.CreateUserParams{ + FirstName: "Imad", + LastName: "Archid", + Dob: time.Now(), + Balance: 10000, + Email: "imad@exchange.co", + }) + + log.Println("done.") +} diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..77dc8e9 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,15 @@ +services: + postgres: + image: postgres:16-alpine + container_name: test_postgres + environment: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + POSTGRES_DB: testdb + ports: + - "5432:5432" + healthcheck: + test: ["CMD-SHELL", "pg_isready -U postgres"] + interval: 5s + timeout: 5s + retries: 5 diff --git a/go.mod b/go.mod index d0c0153..5a144a6 100644 --- a/go.mod +++ b/go.mod @@ -3,3 +3,8 @@ module exchange go 1.24.2 require github.com/google/uuid v1.6.0 + +require ( + github.com/go-chi/chi/v5 v5.2.1 // indirect + github.com/lib/pq v1.10.9 // indirect +) diff --git a/go.sum b/go.sum index 7790d7c..c677be2 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,6 @@ +github.com/go-chi/chi/v5 v5.2.1 h1:KOIHODQj58PmL80G2Eak4WdvUzjSJSm0vG72crDCqb8= +github.com/go-chi/chi/v5 v5.2.1/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= diff --git a/internal/api/handler/asset.go b/internal/api/handler/asset.go new file mode 100644 index 0000000..1e20877 --- /dev/null +++ b/internal/api/handler/asset.go @@ -0,0 +1,23 @@ +package handler + +import ( + "encoding/json" + "fmt" + "net/http" +) + +func (h *Handler) GetAssets(w http.ResponseWriter, r *http.Request) { + assets, err := h.Queries.GetAllAssets(r.Context()) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + json.NewEncoder(w).Encode(ErrorResponse{ + Code: "ASSET_RETRIEVAL_FAILED", + Message: "Failed to retrieve assets", + }) + fmt.Print(err) + return + } + + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(assets) +} diff --git a/internal/api/handler/order.go b/internal/api/handler/order.go new file mode 100644 index 0000000..5fac101 --- /dev/null +++ b/internal/api/handler/order.go @@ -0,0 +1,92 @@ +package handler + +import ( + "encoding/json" + "fmt" + "log" + "net/http" + + "exchange/internal/db" + "exchange/internal/order" + "exchange/internal/orderbook" + "exchange/internal/types" + + "github.com/google/uuid" +) + +type OrderRequest struct { + Amount int32 `json:"amount"` + Price float64 `json:"price"` + Side types.Side `json:"side"` + OrderType types.OrderType `json:"order_type"` + Ticker string `json:"ticker"` +} + +type ErrorResponse struct { + Code string `json:"error"` + Message string `json:"message"` +} + +type Handler struct { + Queries *db.Queries + OrderBooks map[string]*orderbook.OrderBook + ValidTickers map[string]struct{} +} + +func (h *Handler) SubmitOrder(w http.ResponseWriter, r *http.Request) { + var req OrderRequest + + err := json.NewDecoder(r.Body).Decode(&req) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(ErrorResponse{ + Code: "INVALID_REQUEST_BODY", + Message: "Failed to decode request body", + }) + return + } + + s := "6d28e047-27c2-457a-9c4f-d68af05d6c8e" + id, err := uuid.Parse(s) + if err != nil { + log.Fatal("Invalid UUID:", err) + } + + order := order.NewOrder(req.Price, req.Amount, db.OrderSideType(req.Side), db.OrderType(req.OrderType), req.Ticker, id) + if order.IsValid() { + status := h.OrderBooks[req.Ticker].Submit(order, h.Queries) + if status { + w.WriteHeader(http.StatusCreated) + json.NewEncoder(w).Encode(order) + } else { + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(ErrorResponse{ + Code: "ORDERBOOK_SUBMISSION_FAILED", + Message: "Failed to submit order to orderbook", + }) + } + + } else { + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(ErrorResponse{ + Code: "ORDER_NOT_VALID", + Message: "Order is not valid", + }) + } +} + +func (h *Handler) GetOrders(w http.ResponseWriter, r *http.Request) { + orders, err := h.Queries.GetAllOrders(r.Context()) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + json.NewEncoder(w).Encode(ErrorResponse{ + Code: "ORDER_RETRIEVAL_FAILED", + Message: "Failed to retrieve orders", + }) + fmt.Print(err) + return + } + + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(orders) +} diff --git a/internal/api/handler/transaction.go b/internal/api/handler/transaction.go new file mode 100644 index 0000000..2aa1c1f --- /dev/null +++ b/internal/api/handler/transaction.go @@ -0,0 +1,39 @@ +package handler + +import ( + "context" + "exchange/internal/db" + "exchange/internal/events" + "fmt" +) + +func StartTransactionPersistenceWorker(queries *db.Queries) { + for event := range events.TransactionEventChan { + + _, err := queries.CreateTransaction(context.Background(), db.CreateTransactionParams{ + BuyerOrder: event.BuyerOrder.ID, + SellerOrder: event.SellerOrder.ID, + Price: event.Price, + Amount: event.Amount, + Asset: event.Asset, + CreatedAt: event.Timestamp, + }) + + if err == nil { + fmt.Print("Buyer: ", event.BuyerOrder, "\n") + queries.UpdateOrderStatus(context.Background(), db.UpdateOrderStatusParams{ + ID: event.BuyerOrder.ID, + OrderStatus: event.BuyerOrder.Status, + }) + fmt.Print("Seller: ", event.SellerOrder, "\n") + + queries.UpdateOrderStatus(context.Background(), db.UpdateOrderStatusParams{ + ID: event.SellerOrder.ID, + OrderStatus: event.SellerOrder.Status, + }) + } else { + fmt.Print(err) + } + + } +} diff --git a/internal/api/middlewares/json_response.go b/internal/api/middlewares/json_response.go new file mode 100644 index 0000000..9cc7bce --- /dev/null +++ b/internal/api/middlewares/json_response.go @@ -0,0 +1,10 @@ +package middlewares + +import "net/http" + +func JSONMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + next.ServeHTTP(w, r) + }) +} diff --git a/internal/api/router/router.go b/internal/api/router/router.go new file mode 100644 index 0000000..96b4266 --- /dev/null +++ b/internal/api/router/router.go @@ -0,0 +1,35 @@ +package router + +import ( + "encoding/json" + "net/http" + + "exchange/internal/api/handler" + "exchange/internal/api/middlewares" + + "github.com/go-chi/chi/v5" + "github.com/go-chi/chi/v5/middleware" +) + +func NewRouter(h *handler.Handler) http.Handler { + r := chi.NewRouter() + + r.Use(middleware.Logger) + r.Use(middlewares.JSONMiddleware) + + // Default route + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(map[string]string{ + "message": "it works", + }) + }) + + // Orders route + r.Post("/orders", h.SubmitOrder) + r.Get("/orders", h.GetOrders) + + // Assets route + r.Get("/assets", h.GetAssets) + + return r +} diff --git a/internal/db/asset.sql.go b/internal/db/asset.sql.go new file mode 100644 index 0000000..4a096b8 --- /dev/null +++ b/internal/db/asset.sql.go @@ -0,0 +1,110 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.29.0 +// source: asset.sql + +package db + +import ( + "context" +) + +const CreateAsset = `-- name: CreateAsset :one +INSERT INTO assets (ticker, asset_name) +VALUES ($1, $2) +RETURNING id +` + +type CreateAssetParams struct { + Ticker string `json:"ticker"` + AssetName string `json:"asset_name"` +} + +func (q *Queries) CreateAsset(ctx context.Context, arg CreateAssetParams) (int32, error) { + row := q.db.QueryRowContext(ctx, CreateAsset, arg.Ticker, arg.AssetName) + var id int32 + err := row.Scan(&id) + return id, err +} + +const DeleteAsset = `-- name: DeleteAsset :one +DELETE FROM assets +WHERE ticker = $1 +RETURNING id +` + +func (q *Queries) DeleteAsset(ctx context.Context, ticker string) (int32, error) { + row := q.db.QueryRowContext(ctx, DeleteAsset, ticker) + var id int32 + err := row.Scan(&id) + return id, err +} + +const GetAllAssets = `-- name: GetAllAssets :many +SELECT id, ticker, asset_name, is_tradable FROM assets +` + +func (q *Queries) GetAllAssets(ctx context.Context) ([]Asset, error) { + rows, err := q.db.QueryContext(ctx, GetAllAssets) + if err != nil { + return nil, err + } + defer rows.Close() + items := []Asset{} + for rows.Next() { + var i Asset + if err := rows.Scan( + &i.ID, + &i.Ticker, + &i.AssetName, + &i.IsTradable, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const GetAssetByTicker = `-- name: GetAssetByTicker :one +SELECT id, ticker, asset_name, is_tradable FROM assets +WHERE ticker = $1 +` + +func (q *Queries) GetAssetByTicker(ctx context.Context, ticker string) (Asset, error) { + row := q.db.QueryRowContext(ctx, GetAssetByTicker, ticker) + var i Asset + err := row.Scan( + &i.ID, + &i.Ticker, + &i.AssetName, + &i.IsTradable, + ) + return i, err +} + +const UpdateAsset = `-- name: UpdateAsset :one +UPDATE assets +SET asset_name = $1, is_tradable = $2 +WHERE ticker = $3 +RETURNING id +` + +type UpdateAssetParams struct { + AssetName string `json:"asset_name"` + IsTradable bool `json:"is_tradable"` + Ticker string `json:"ticker"` +} + +func (q *Queries) UpdateAsset(ctx context.Context, arg UpdateAssetParams) (int32, error) { + row := q.db.QueryRowContext(ctx, UpdateAsset, arg.AssetName, arg.IsTradable, arg.Ticker) + var id int32 + err := row.Scan(&id) + return id, err +} diff --git a/internal/db/db.go b/internal/db/db.go new file mode 100644 index 0000000..0c56c2b --- /dev/null +++ b/internal/db/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.29.0 + +package db + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/db/migration/20250428_init.up.sql b/internal/db/migration/20250428_init.up.sql new file mode 100644 index 0000000..254eaac --- /dev/null +++ b/internal/db/migration/20250428_init.up.sql @@ -0,0 +1,50 @@ +CREATE TYPE order_side_type AS ENUM ('BUY', 'SELL'); +CREATE TYPE order_status_type AS ENUM ('PENDING', 'SUBMITTED', 'CANCELED', 'PARTIALLY_FILLED', 'FILLED'); +CREATE TYPE order_type AS ENUM ('MARKET', 'LIMIT'); + +CREATE EXTENSION IF NOT EXISTS pgcrypto; + +CREATE TABLE users ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + first_name VARCHAR(255) NOT NULL, + last_name VARCHAR(255) NOT NULL, + email VARCHAR(255) NOT NULL UNIQUE, + dob DATE NOT NULL, + balance DOUBLE PRECISION NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE assets ( + id SERIAL PRIMARY KEY, + ticker VARCHAR(5) NOT NULL UNIQUE, + asset_name VARCHAR(255) NOT NULL, + is_tradable BOOLEAN NOT NULL DEFAULT TRUE +); + +CREATE TABLE orders ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + price DOUBLE PRECISION NOT NULL, + amount INT NOT NULL, + side order_side_type NOT NULL, + order_type order_type NOT NULL, + asset VARCHAR(5) NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + created_by UUID NOT NULL, + order_status order_status_type NOT NULL DEFAULT 'SUBMITTED', + FOREIGN KEY (created_by) REFERENCES users(id) ON DELETE RESTRICT, + FOREIGN KEY (asset) REFERENCES assets(ticker) +); + +CREATE TABLE transactions ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + price DOUBLE PRECISION NOT NULL, + amount INT NOT NULL, + buyer_order UUID NOT NULL, + seller_order UUID NOT NULL, + asset VARCHAR(5) NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (buyer_order) REFERENCES orders(id), + FOREIGN KEY (seller_order) REFERENCES orders(id), + FOREIGN KEY (asset) REFERENCES assets(ticker) +); + diff --git a/internal/db/models.go b/internal/db/models.go new file mode 100644 index 0000000..196fa9a --- /dev/null +++ b/internal/db/models.go @@ -0,0 +1,181 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.29.0 + +package db + +import ( + "database/sql/driver" + "fmt" + "time" + + "github.com/google/uuid" +) + +type OrderSideType string + +const ( + OrderSideTypeBUY OrderSideType = "BUY" + OrderSideTypeSELL OrderSideType = "SELL" +) + +func (e *OrderSideType) Scan(src interface{}) error { + switch s := src.(type) { + case []byte: + *e = OrderSideType(s) + case string: + *e = OrderSideType(s) + default: + return fmt.Errorf("unsupported scan type for OrderSideType: %T", src) + } + return nil +} + +type NullOrderSideType struct { + OrderSideType OrderSideType `json:"order_side_type"` + Valid bool `json:"valid"` // Valid is true if OrderSideType is not NULL +} + +// Scan implements the Scanner interface. +func (ns *NullOrderSideType) Scan(value interface{}) error { + if value == nil { + ns.OrderSideType, ns.Valid = "", false + return nil + } + ns.Valid = true + return ns.OrderSideType.Scan(value) +} + +// Value implements the driver Valuer interface. +func (ns NullOrderSideType) Value() (driver.Value, error) { + if !ns.Valid { + return nil, nil + } + return string(ns.OrderSideType), nil +} + +type OrderStatusType string + +const ( + OrderStatusTypePENDING OrderStatusType = "PENDING" + OrderStatusTypeSUBMITTED OrderStatusType = "SUBMITTED" + OrderStatusTypeCANCELED OrderStatusType = "CANCELED" + OrderStatusTypePARTIALLYFILLED OrderStatusType = "PARTIALLY_FILLED" + OrderStatusTypeFILLED OrderStatusType = "FILLED" +) + +func (e *OrderStatusType) Scan(src interface{}) error { + switch s := src.(type) { + case []byte: + *e = OrderStatusType(s) + case string: + *e = OrderStatusType(s) + default: + return fmt.Errorf("unsupported scan type for OrderStatusType: %T", src) + } + return nil +} + +type NullOrderStatusType struct { + OrderStatusType OrderStatusType `json:"order_status_type"` + Valid bool `json:"valid"` // Valid is true if OrderStatusType is not NULL +} + +// Scan implements the Scanner interface. +func (ns *NullOrderStatusType) Scan(value interface{}) error { + if value == nil { + ns.OrderStatusType, ns.Valid = "", false + return nil + } + ns.Valid = true + return ns.OrderStatusType.Scan(value) +} + +// Value implements the driver Valuer interface. +func (ns NullOrderStatusType) Value() (driver.Value, error) { + if !ns.Valid { + return nil, nil + } + return string(ns.OrderStatusType), nil +} + +type OrderType string + +const ( + OrderTypeMARKET OrderType = "MARKET" + OrderTypeLIMIT OrderType = "LIMIT" +) + +func (e *OrderType) Scan(src interface{}) error { + switch s := src.(type) { + case []byte: + *e = OrderType(s) + case string: + *e = OrderType(s) + default: + return fmt.Errorf("unsupported scan type for OrderType: %T", src) + } + return nil +} + +type NullOrderType struct { + OrderType OrderType `json:"order_type"` + Valid bool `json:"valid"` // Valid is true if OrderType is not NULL +} + +// Scan implements the Scanner interface. +func (ns *NullOrderType) Scan(value interface{}) error { + if value == nil { + ns.OrderType, ns.Valid = "", false + return nil + } + ns.Valid = true + return ns.OrderType.Scan(value) +} + +// Value implements the driver Valuer interface. +func (ns NullOrderType) Value() (driver.Value, error) { + if !ns.Valid { + return nil, nil + } + return string(ns.OrderType), nil +} + +type Asset struct { + ID int32 `json:"id"` + Ticker string `json:"ticker"` + AssetName string `json:"asset_name"` + IsTradable bool `json:"is_tradable"` +} + +type Order struct { + ID uuid.UUID `json:"id"` + Price float64 `json:"price"` + Amount int32 `json:"amount"` + Side OrderSideType `json:"side"` + OrderType OrderType `json:"order_type"` + Asset string `json:"asset"` + CreatedAt time.Time `json:"created_at"` + CreatedBy uuid.UUID `json:"created_by"` + OrderStatus OrderStatusType `json:"order_status"` +} + +type Transaction struct { + ID uuid.UUID `json:"id"` + Price float64 `json:"price"` + Amount int32 `json:"amount"` + BuyerOrder uuid.UUID `json:"buyer_order"` + SellerOrder uuid.UUID `json:"seller_order"` + Asset string `json:"asset"` + CreatedAt time.Time `json:"created_at"` +} + +type User struct { + ID uuid.UUID `json:"id"` + FirstName string `json:"first_name"` + LastName string `json:"last_name"` + Email string `json:"email"` + Dob time.Time `json:"dob"` + Balance float64 `json:"balance"` + CreatedAt time.Time `json:"created_at"` +} diff --git a/internal/db/order.sql.go b/internal/db/order.sql.go new file mode 100644 index 0000000..1b74f1c --- /dev/null +++ b/internal/db/order.sql.go @@ -0,0 +1,198 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.29.0 +// source: order.sql + +package db + +import ( + "context" + + "github.com/google/uuid" +) + +const CreateOrder = `-- name: CreateOrder :one +INSERT INTO orders (price, amount, side, order_type, asset, created_by) +VALUES ($1, $2, $3, $4, $5, $6) +RETURNING id +` + +type CreateOrderParams struct { + Price float64 `json:"price"` + Amount int32 `json:"amount"` + Side OrderSideType `json:"side"` + OrderType OrderType `json:"order_type"` + Asset string `json:"asset"` + CreatedBy uuid.UUID `json:"created_by"` +} + +func (q *Queries) CreateOrder(ctx context.Context, arg CreateOrderParams) (uuid.UUID, error) { + row := q.db.QueryRowContext(ctx, CreateOrder, + arg.Price, + arg.Amount, + arg.Side, + arg.OrderType, + arg.Asset, + arg.CreatedBy, + ) + var id uuid.UUID + err := row.Scan(&id) + return id, err +} + +const GetAllOrders = `-- name: GetAllOrders :many +SELECT id, price, amount, side, order_type, asset, created_at, created_by, order_status FROM orders +ORDER BY created_at DESC +` + +func (q *Queries) GetAllOrders(ctx context.Context) ([]Order, error) { + rows, err := q.db.QueryContext(ctx, GetAllOrders) + if err != nil { + return nil, err + } + defer rows.Close() + items := []Order{} + for rows.Next() { + var i Order + if err := rows.Scan( + &i.ID, + &i.Price, + &i.Amount, + &i.Side, + &i.OrderType, + &i.Asset, + &i.CreatedAt, + &i.CreatedBy, + &i.OrderStatus, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const GetOrderById = `-- name: GetOrderById :one +SELECT id, price, amount, side, order_type, asset, created_at, created_by, order_status FROM orders +WHERE id = $1 +` + +func (q *Queries) GetOrderById(ctx context.Context, id uuid.UUID) (Order, error) { + row := q.db.QueryRowContext(ctx, GetOrderById, id) + var i Order + err := row.Scan( + &i.ID, + &i.Price, + &i.Amount, + &i.Side, + &i.OrderType, + &i.Asset, + &i.CreatedAt, + &i.CreatedBy, + &i.OrderStatus, + ) + return i, err +} + +const GetOrdersByUser = `-- name: GetOrdersByUser :many +SELECT id, price, amount, side, order_type, asset, created_at, created_by, order_status FROM orders +WHERE created_by = $1 +ORDER BY created_at DESC +` + +func (q *Queries) GetOrdersByUser(ctx context.Context, createdBy uuid.UUID) ([]Order, error) { + rows, err := q.db.QueryContext(ctx, GetOrdersByUser, createdBy) + if err != nil { + return nil, err + } + defer rows.Close() + items := []Order{} + for rows.Next() { + var i Order + if err := rows.Scan( + &i.ID, + &i.Price, + &i.Amount, + &i.Side, + &i.OrderType, + &i.Asset, + &i.CreatedAt, + &i.CreatedBy, + &i.OrderStatus, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const GetSubmittedOrders = `-- name: GetSubmittedOrders :many +SELECT id, price, amount, side, order_type, asset, created_at, created_by, order_status FROM orders +WHERE order_status = 'SUBMITTED' +ORDER BY created_at DESC +` + +func (q *Queries) GetSubmittedOrders(ctx context.Context) ([]Order, error) { + rows, err := q.db.QueryContext(ctx, GetSubmittedOrders) + if err != nil { + return nil, err + } + defer rows.Close() + items := []Order{} + for rows.Next() { + var i Order + if err := rows.Scan( + &i.ID, + &i.Price, + &i.Amount, + &i.Side, + &i.OrderType, + &i.Asset, + &i.CreatedAt, + &i.CreatedBy, + &i.OrderStatus, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const UpdateOrderStatus = `-- name: UpdateOrderStatus :one +UPDATE orders +SET order_status = $1 +WHERE id = $2 +RETURNING id +` + +type UpdateOrderStatusParams struct { + OrderStatus OrderStatusType `json:"order_status"` + ID uuid.UUID `json:"id"` +} + +func (q *Queries) UpdateOrderStatus(ctx context.Context, arg UpdateOrderStatusParams) (uuid.UUID, error) { + row := q.db.QueryRowContext(ctx, UpdateOrderStatus, arg.OrderStatus, arg.ID) + var id uuid.UUID + err := row.Scan(&id) + return id, err +} diff --git a/internal/db/query/asset.sql b/internal/db/query/asset.sql new file mode 100644 index 0000000..45a13a6 --- /dev/null +++ b/internal/db/query/asset.sql @@ -0,0 +1,22 @@ +-- name: CreateAsset :one +INSERT INTO assets (ticker, asset_name) +VALUES ($1, $2) +RETURNING id; + +-- name: DeleteAsset :one +DELETE FROM assets +WHERE ticker = $1 +RETURNING id; + +-- name: UpdateAsset :one +UPDATE assets +SET asset_name = $1, is_tradable = $2 +WHERE ticker = $3 +RETURNING id; + +-- name: GetAssetByTicker :one +SELECT * FROM assets +WHERE ticker = $1; + +-- name: GetAllAssets :many +SELECT * FROM assets; diff --git a/internal/db/query/order.sql b/internal/db/query/order.sql new file mode 100644 index 0000000..fd28e1f --- /dev/null +++ b/internal/db/query/order.sql @@ -0,0 +1,28 @@ +-- name: CreateOrder :one +INSERT INTO orders (price, amount, side, order_type, asset, created_by) +VALUES ($1, $2, $3, $4, $5, $6) +RETURNING id; + +-- name: GetOrderById :one +SELECT * FROM orders +WHERE id = $1; + +-- name: GetOrdersByUser :many +SELECT * FROM orders +WHERE created_by = $1 +ORDER BY created_at DESC; + +-- name: GetAllOrders :many +SELECT * FROM orders +ORDER BY created_at DESC; + +-- name: UpdateOrderStatus :one +UPDATE orders +SET order_status = $1 +WHERE id = $2 +RETURNING id; + +-- name: GetSubmittedOrders :many +SELECT * FROM orders +WHERE order_status = 'SUBMITTED' +ORDER BY created_at DESC; \ No newline at end of file diff --git a/internal/db/query/transaction.sql b/internal/db/query/transaction.sql new file mode 100644 index 0000000..7bff010 --- /dev/null +++ b/internal/db/query/transaction.sql @@ -0,0 +1,18 @@ +-- name: CreateTransaction :one +INSERT INTO transactions (price, amount, buyer_order, seller_order, asset, created_at) +VALUES ($1, $2, $3, $4, $5, $6) +RETURNING id; + +-- name: GetTransactionById :one +SELECT * FROM transactions +WHERE id = $1; + +-- name: GetTransactionsByUser :many +SELECT * FROM transactions +JOIN orders ON transactions.buyer_order = orders.id OR transactions.seller_order = orders.id +WHERE orders.created_by = $1 +ORDER BY orders.created_at DESC; + +-- name: GetAllTransactions :many +SELECT * FROM transactions +ORDER BY created_at DESC; diff --git a/internal/db/query/user.sql b/internal/db/query/user.sql new file mode 100644 index 0000000..8b83ba3 --- /dev/null +++ b/internal/db/query/user.sql @@ -0,0 +1,24 @@ +-- name: CreateUser :one +INSERT INTO users (first_name, last_name, dob, email, balance) +VALUES ($1, $2, $3, $4, $5) +RETURNING id; + +-- name: DeleteUser :one +DELETE FROM users +WHERE id = $1 +RETURNING id; +-- name: UpdateUser :one +UPDATE users +SET first_name = $1, last_name = $2, dob = $3, balance = $4 +WHERE id = $5 +RETURNING id; + +-- name: UpdateUserBalance :one +UPDATE users +SET balance = $1 +WHERE id = $2 +RETURNING id; + +-- name: GetAllUsers :many +SELECT * FROM users +ORDER BY created_at DESC; diff --git a/internal/db/transaction.sql.go b/internal/db/transaction.sql.go new file mode 100644 index 0000000..ae3912f --- /dev/null +++ b/internal/db/transaction.sql.go @@ -0,0 +1,164 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.29.0 +// source: transaction.sql + +package db + +import ( + "context" + "time" + + "github.com/google/uuid" +) + +const CreateTransaction = `-- name: CreateTransaction :one +INSERT INTO transactions (price, amount, buyer_order, seller_order, asset, created_at) +VALUES ($1, $2, $3, $4, $5, $6) +RETURNING id +` + +type CreateTransactionParams struct { + Price float64 `json:"price"` + Amount int32 `json:"amount"` + BuyerOrder uuid.UUID `json:"buyer_order"` + SellerOrder uuid.UUID `json:"seller_order"` + Asset string `json:"asset"` + CreatedAt time.Time `json:"created_at"` +} + +func (q *Queries) CreateTransaction(ctx context.Context, arg CreateTransactionParams) (uuid.UUID, error) { + row := q.db.QueryRowContext(ctx, CreateTransaction, + arg.Price, + arg.Amount, + arg.BuyerOrder, + arg.SellerOrder, + arg.Asset, + arg.CreatedAt, + ) + var id uuid.UUID + err := row.Scan(&id) + return id, err +} + +const GetAllTransactions = `-- name: GetAllTransactions :many +SELECT id, price, amount, buyer_order, seller_order, asset, created_at FROM transactions +ORDER BY created_at DESC +` + +func (q *Queries) GetAllTransactions(ctx context.Context) ([]Transaction, error) { + rows, err := q.db.QueryContext(ctx, GetAllTransactions) + if err != nil { + return nil, err + } + defer rows.Close() + items := []Transaction{} + for rows.Next() { + var i Transaction + if err := rows.Scan( + &i.ID, + &i.Price, + &i.Amount, + &i.BuyerOrder, + &i.SellerOrder, + &i.Asset, + &i.CreatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const GetTransactionById = `-- name: GetTransactionById :one +SELECT id, price, amount, buyer_order, seller_order, asset, created_at FROM transactions +WHERE id = $1 +` + +func (q *Queries) GetTransactionById(ctx context.Context, id uuid.UUID) (Transaction, error) { + row := q.db.QueryRowContext(ctx, GetTransactionById, id) + var i Transaction + err := row.Scan( + &i.ID, + &i.Price, + &i.Amount, + &i.BuyerOrder, + &i.SellerOrder, + &i.Asset, + &i.CreatedAt, + ) + return i, err +} + +const GetTransactionsByUser = `-- name: GetTransactionsByUser :many +SELECT transactions.id, transactions.price, transactions.amount, buyer_order, seller_order, transactions.asset, transactions.created_at, orders.id, orders.price, orders.amount, side, order_type, orders.asset, orders.created_at, created_by, order_status FROM transactions +JOIN orders ON transactions.buyer_order = orders.id OR transactions.seller_order = orders.id +WHERE orders.created_by = $1 +ORDER BY orders.created_at DESC +` + +type GetTransactionsByUserRow struct { + ID uuid.UUID `json:"id"` + Price float64 `json:"price"` + Amount int32 `json:"amount"` + BuyerOrder uuid.UUID `json:"buyer_order"` + SellerOrder uuid.UUID `json:"seller_order"` + Asset string `json:"asset"` + CreatedAt time.Time `json:"created_at"` + ID_2 uuid.UUID `json:"id_2"` + Price_2 float64 `json:"price_2"` + Amount_2 int32 `json:"amount_2"` + Side OrderSideType `json:"side"` + OrderType OrderType `json:"order_type"` + Asset_2 string `json:"asset_2"` + CreatedAt_2 time.Time `json:"created_at_2"` + CreatedBy uuid.UUID `json:"created_by"` + OrderStatus OrderStatusType `json:"order_status"` +} + +func (q *Queries) GetTransactionsByUser(ctx context.Context, createdBy uuid.UUID) ([]GetTransactionsByUserRow, error) { + rows, err := q.db.QueryContext(ctx, GetTransactionsByUser, createdBy) + if err != nil { + return nil, err + } + defer rows.Close() + items := []GetTransactionsByUserRow{} + for rows.Next() { + var i GetTransactionsByUserRow + if err := rows.Scan( + &i.ID, + &i.Price, + &i.Amount, + &i.BuyerOrder, + &i.SellerOrder, + &i.Asset, + &i.CreatedAt, + &i.ID_2, + &i.Price_2, + &i.Amount_2, + &i.Side, + &i.OrderType, + &i.Asset_2, + &i.CreatedAt_2, + &i.CreatedBy, + &i.OrderStatus, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} diff --git a/internal/db/user.sql.go b/internal/db/user.sql.go new file mode 100644 index 0000000..38f34f2 --- /dev/null +++ b/internal/db/user.sql.go @@ -0,0 +1,135 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.29.0 +// source: user.sql + +package db + +import ( + "context" + "time" + + "github.com/google/uuid" +) + +const CreateUser = `-- name: CreateUser :one +INSERT INTO users (first_name, last_name, dob, email, balance) +VALUES ($1, $2, $3, $4, $5) +RETURNING id +` + +type CreateUserParams struct { + FirstName string `json:"first_name"` + LastName string `json:"last_name"` + Dob time.Time `json:"dob"` + Email string `json:"email"` + Balance float64 `json:"balance"` +} + +func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (uuid.UUID, error) { + row := q.db.QueryRowContext(ctx, CreateUser, + arg.FirstName, + arg.LastName, + arg.Dob, + arg.Email, + arg.Balance, + ) + var id uuid.UUID + err := row.Scan(&id) + return id, err +} + +const DeleteUser = `-- name: DeleteUser :one +DELETE FROM users +WHERE id = $1 +RETURNING id +` + +func (q *Queries) DeleteUser(ctx context.Context, id uuid.UUID) (uuid.UUID, error) { + row := q.db.QueryRowContext(ctx, DeleteUser, id) + err := row.Scan(&id) + return id, err +} + +const GetAllUsers = `-- name: GetAllUsers :many +SELECT id, first_name, last_name, email, dob, balance, created_at FROM users +ORDER BY created_at DESC +` + +func (q *Queries) GetAllUsers(ctx context.Context) ([]User, error) { + rows, err := q.db.QueryContext(ctx, GetAllUsers) + if err != nil { + return nil, err + } + defer rows.Close() + items := []User{} + for rows.Next() { + var i User + if err := rows.Scan( + &i.ID, + &i.FirstName, + &i.LastName, + &i.Email, + &i.Dob, + &i.Balance, + &i.CreatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const UpdateUser = `-- name: UpdateUser :one +UPDATE users +SET first_name = $1, last_name = $2, dob = $3, balance = $4 +WHERE id = $5 +RETURNING id +` + +type UpdateUserParams struct { + FirstName string `json:"first_name"` + LastName string `json:"last_name"` + Dob time.Time `json:"dob"` + Balance float64 `json:"balance"` + ID uuid.UUID `json:"id"` +} + +func (q *Queries) UpdateUser(ctx context.Context, arg UpdateUserParams) (uuid.UUID, error) { + row := q.db.QueryRowContext(ctx, UpdateUser, + arg.FirstName, + arg.LastName, + arg.Dob, + arg.Balance, + arg.ID, + ) + var id uuid.UUID + err := row.Scan(&id) + return id, err +} + +const UpdateUserBalance = `-- name: UpdateUserBalance :one +UPDATE users +SET balance = $1 +WHERE id = $2 +RETURNING id +` + +type UpdateUserBalanceParams struct { + Balance float64 `json:"balance"` + ID uuid.UUID `json:"id"` +} + +func (q *Queries) UpdateUserBalance(ctx context.Context, arg UpdateUserBalanceParams) (uuid.UUID, error) { + row := q.db.QueryRowContext(ctx, UpdateUserBalance, arg.Balance, arg.ID) + var id uuid.UUID + err := row.Scan(&id) + return id, err +} diff --git a/internal/events/events.go b/internal/events/events.go new file mode 100644 index 0000000..c167698 --- /dev/null +++ b/internal/events/events.go @@ -0,0 +1,30 @@ +package events + +import ( + "exchange/internal/db" + "exchange/internal/order" + "time" + + "github.com/google/uuid" +) + +type OrderEvent struct { + Price float64 + Amount int32 + Side db.OrderSideType + OrderType db.OrderType + Ticker string + CreatedBy uuid.UUID +} + +type TransactionEvent struct { + Price float64 + Amount int32 + BuyerOrder *order.Order + SellerOrder *order.Order + Asset string + Timestamp time.Time +} + +var TransactionEventChan = make(chan TransactionEvent, 10000) +var NewOrderChan = make(chan OrderEvent, 10000) diff --git a/internal/order/order.go b/internal/order/order.go index 75e2090..83d5100 100644 --- a/internal/order/order.go +++ b/internal/order/order.go @@ -1,33 +1,35 @@ package order import ( - "exchange/internal/types" + "exchange/internal/db" "time" "github.com/google/uuid" ) type Order struct { - ID string - Price float64 - Amount float64 - Side types.Side - Time time.Time - Type types.OrderType - Status types.Status - Ticker string + ID uuid.UUID + Price float64 + Amount int32 + Side db.OrderSideType + Time time.Time + Type db.OrderType + Status db.OrderStatusType + Ticker string + CreatedBy uuid.UUID } -func NewOrder(price float64, amount float64, side types.Side, orderType types.OrderType, asset string) *Order { +func NewOrder(price float64, amount int32, side db.OrderSideType, orderType db.OrderType, ticker string, createdBy uuid.UUID) *Order { return &Order{ - ID: uuid.New().String(), - Price: price, - Amount: amount, - Side: side, - Time: time.Now().UTC(), - Type: orderType, - Status: types.Pending, - Ticker: asset, + ID: uuid.New(), + Price: price, + Amount: amount, + Side: side, + Time: time.Now().UTC(), + Type: orderType, + Status: "SUBMITTED", + Ticker: ticker, + CreatedBy: createdBy, } } @@ -35,8 +37,9 @@ func (o *Order) IsValid() bool { if o.Price <= 0 || o.Amount <= 0 { return false } - if o.Side != types.Buy && o.Side != types.Sell { + if o.Side != db.OrderSideTypeBUY && o.Side != db.OrderSideTypeSELL { return false } + return true } diff --git a/internal/order/order_service.go b/internal/order/order_service.go new file mode 100644 index 0000000..52a57da --- /dev/null +++ b/internal/order/order_service.go @@ -0,0 +1,98 @@ +package order + +import ( + "context" + "exchange/internal/db" + "sync" + + "github.com/google/uuid" +) + +type OrderService struct { + orderbook OrderBookInterface + db *db.Queries + mu sync.RWMutex +} + +// OrderBookInterface defines the interface for orderbook operations +type OrderBookInterface interface { + Submit(*Order) bool + Withdraw(*Order) bool +} + +func NewOrderService(orderbook OrderBookInterface, db *db.Queries) *OrderService { + return &OrderService{ + orderbook: orderbook, + db: db, + } +} + +// SubmitOrder handles the submission of a new order, ensuring consistency between memory and database +func (s *OrderService) SubmitOrder(ctx context.Context, o *Order) error { + s.mu.Lock() + defer s.mu.Unlock() + + // First, try to submit to the orderbook + success := s.orderbook.Submit(o) + if !success { + return nil // Order was rejected by the orderbook + } + + // If successful in memory, persist to database + _, err := s.db.CreateOrder(ctx, db.CreateOrderParams{ + Price: o.Price, + Amount: o.Amount, + Side: o.Side, + OrderType: o.Type, + Asset: o.Ticker, + CreatedBy: o.CreatedBy, + }) + + if err != nil { + // If database operation fails, we need to rollback the memory operation + s.orderbook.Withdraw(o) + return err + } + + return nil +} + +// UpdateOrderStatus updates both memory and database order status +func (s *OrderService) UpdateOrderStatus(ctx context.Context, orderID uuid.UUID, status db.OrderStatusType) error { + s.mu.Lock() + defer s.mu.Unlock() + + // Update in database first + _, err := s.db.UpdateOrderStatus(ctx, db.UpdateOrderStatusParams{ + OrderStatus: status, + ID: orderID, + }) + if err != nil { + return err + } + + // If database update successful, update in memory + // Note: This is a simplified version. In a real implementation, you'd need to + // find the order in the appropriate heap and update its status + return nil +} + +// GetOrder retrieves an order from the database +func (s *OrderService) GetOrder(ctx context.Context, orderID uuid.UUID) (*Order, error) { + dbOrder, err := s.db.GetOrderById(ctx, orderID) + if err != nil { + return nil, err + } + + return &Order{ + ID: dbOrder.ID, + Price: dbOrder.Price, + Amount: dbOrder.Amount, + Side: dbOrder.Side, + Time: dbOrder.CreatedAt, + Type: dbOrder.OrderType, + Status: dbOrder.OrderStatus, + Ticker: dbOrder.Asset, + CreatedBy: dbOrder.CreatedBy, + }, nil +} diff --git a/internal/order/order_test.go b/internal/order/order_test.go index aa22bdc..141e335 100644 --- a/internal/order/order_test.go +++ b/internal/order/order_test.go @@ -1,33 +1,34 @@ package order import ( - "exchange/internal/types" + "exchange/internal/db" "testing" + + "github.com/google/uuid" ) func TestNewOrder(t *testing.T) { - order := NewOrder(10, 100, types.Buy, types.Limit, "LINK") - if len(order.ID) == 0 { + if len(NewOrder(10, 100, db.OrderSideTypeBUY, db.OrderTypeLIMIT, "LINK", uuid.New()).ID) == 0 { t.Errorf("The order submitted does not have a valid ID") } } func TestOrderNegativeParams(t *testing.T) { - order := NewOrder(-4, 0, types.Buy, types.Limit, "LINK") + order := NewOrder(-4, 0, db.OrderSideTypeBUY, db.OrderTypeLIMIT, "LINK", uuid.New()) if order.IsValid() == true { t.Errorf("Order has illegal parameters.") } } func TestOrderNoSides(t *testing.T) { - order := NewOrder(10, 100, "S", "B", "LINK") + order := NewOrder(10, 100, "S", "B", "LINK", uuid.New()) if order.IsValid() == true { t.Errorf("Order has illegal parameters.") } } func TestValidOrder(t *testing.T) { - order := NewOrder(10, 100, types.Buy, types.Limit, "LINK") + order := NewOrder(10, 100, db.OrderSideTypeBUY, db.OrderTypeLIMIT, "LINK", uuid.New()) if order.IsValid() != true { t.Errorf("Expected valid order.") } diff --git a/internal/orderbook/heap.go b/internal/orderbook/heap.go index d43a3f3..429ce33 100644 --- a/internal/orderbook/heap.go +++ b/internal/orderbook/heap.go @@ -2,6 +2,8 @@ package orderbook import ( "exchange/internal/order" + + "github.com/google/uuid" ) type OrderHeap struct { @@ -90,7 +92,7 @@ func (h *OrderHeap) compareOrders(i, j int) int { return 0 } -func (ob *OrderBook) removeFromHeap(h *OrderHeap, id string) bool { +func (ob *OrderBook) removeFromHeap(h *OrderHeap, id uuid.UUID) bool { for i, val := range h.orders { if val.ID == id { lastIndex := len(h.orders) - 1 diff --git a/internal/orderbook/orderbook.go b/internal/orderbook/orderbook.go index ce0af8a..f14a045 100644 --- a/internal/orderbook/orderbook.go +++ b/internal/orderbook/orderbook.go @@ -1,9 +1,12 @@ package orderbook import ( + "context" + "exchange/internal/db" + "exchange/internal/events" "exchange/internal/order" - "exchange/internal/types" "fmt" + "time" ) type OrderBook struct { @@ -20,39 +23,57 @@ func NewOrderBook(ticker string) *OrderBook { } } -func (ob *OrderBook) Submit(o *order.Order) bool { +// Submit implements the OrderBookInterface +func (ob *OrderBook) Submit(o *order.Order, queries *db.Queries) bool { if !o.IsValid() { fmt.Println("Order not valid. WRONG_ORDER") return false } + if ob == nil { + fmt.Println("Ticker does not exist. BAD_TICKER") + return false + } + if o.Ticker != ob.Ticker { fmt.Println("Ticker does not match orderbook. WRONG_TICKER") return false } if !handleOrderType(o, ob) { - fmt.Printf("%s Order was NOT filled %.2f @ %.2f ID: %s\n", + fmt.Printf("%s Order was NOT filled %d @ %.2f ID: %s\n", o.Side, o.Amount, o.Price, o.ID) - o.Status = types.Cancelled + o.Status = db.OrderStatusTypeCANCELED return false } - fmt.Printf("%s %s Order submitted %.2f @ %.2f ID: %s\n", - o.Side, o.Type, o.Amount, o.Price, o.ID) + o_id, _ := queries.CreateOrder(context.Background(), db.CreateOrderParams{ + Price: o.Price, + Amount: o.Amount, + Side: o.Side, + OrderType: o.Type, + Asset: o.Ticker, + CreatedBy: o.CreatedBy, + }) + + o.ID = o_id + + fmt.Printf("%s %s %s Order submitted %d @ %.2f ID: %s\n", + o.Side, o.Ticker, o.Type, o.Amount, o.Price, o_id.String()) ob.MatchOrders() return true } +// Withdraw implements the OrderBookInterface func (ob *OrderBook) Withdraw(o *order.Order) bool { - if o.Side == types.Buy { + if o.Side == db.OrderSideTypeBUY { if ob.Bids.Len() > 0 { return ob.removeFromHeap(ob.Bids, o.ID) } else { return false } - } else if o.Side == types.Sell { + } else if o.Side == db.OrderSideTypeSELL { if ob.Asks.Len() > 0 { return ob.removeFromHeap(ob.Asks, o.ID) } else { @@ -67,73 +88,82 @@ func (ob *OrderBook) MatchOrders() { bid := ob.Bids.Peek() ask := ob.Asks.Peek() - if bid.Type == types.Market { + if bid.Type == db.OrderTypeMARKET { bid.Price = ask.Price - } else if ask.Type == types.Market { + } else if ask.Type == db.OrderTypeMARKET { ask.Price = bid.Price } if bid.Price >= ask.Price { - tradeAmount := min(bid.Amount, ask.Amount) - - // Log the trade - fmt.Printf("Matched %.2f @ %.2f (Buy ID: %s, Sell ID: %s)\n", - tradeAmount, ask.Price, bid.ID, ask.ID) + tradeAmount := bid.Amount + if ask.Amount < bid.Amount { + tradeAmount = ask.Amount + } bid.Amount -= tradeAmount ask.Amount -= tradeAmount - bid.Status = types.PartiallyFilled - ask.Status = types.PartiallyFilled + bid.Status = db.OrderStatusTypePARTIALLYFILLED + ask.Status = db.OrderStatusTypePARTIALLYFILLED if bid.Amount == 0 { ob.Bids.Delete() - bid.Status = types.Filled + bid.Status = db.OrderStatusTypeFILLED } if ask.Amount == 0 { ob.Asks.Delete() - ask.Status = types.Filled + ask.Status = db.OrderStatusTypeFILLED } + + fmt.Print(bid.Ticker) + + event := events.TransactionEvent{ + Price: ask.Price, + Amount: tradeAmount, + BuyerOrder: bid, + SellerOrder: ask, + Asset: bid.Ticker, + Timestamp: time.Now(), + } + events.TransactionEventChan <- event + + // Log the trade + fmt.Printf("Matched %d @ %.2f (Buy ID: %s, Sell ID: %s)\n", + tradeAmount, ask.Price, bid.ID, ask.ID) + } else { break } } } -func min(a, b float64) float64 { - if a < b { - return a - } - return b -} - func handleOrderType(o *order.Order, ob *OrderBook) bool { switch o.Side { - case types.Buy: + case db.OrderSideTypeBUY: switch o.Type { - case types.Market: + case db.OrderTypeMARKET: if ob.Asks.Len() > 0 { o.Price = ob.Asks.Peek().Price ob.Bids.Insert(o) return true } return false - case types.Limit: + case db.OrderTypeLIMIT: ob.Bids.Insert(o) return true default: return false } - case types.Sell: + case db.OrderSideTypeSELL: switch o.Type { - case types.Market: + case db.OrderTypeMARKET: if ob.Bids.Len() > 0 { o.Price = ob.Bids.Peek().Price ob.Asks.Insert(o) return true } return false - case types.Limit: + case db.OrderTypeLIMIT: ob.Asks.Insert(o) return true default: diff --git a/internal/orderbook/orderbook_test.go b/internal/orderbook/orderbook_test.go index b7b1e70..594e747 100644 --- a/internal/orderbook/orderbook_test.go +++ b/internal/orderbook/orderbook_test.go @@ -1,9 +1,11 @@ package orderbook import ( + "exchange/internal/db" "exchange/internal/order" - "exchange/internal/types" "testing" + + "github.com/google/uuid" ) func TestNewOrderBook(t *testing.T) { @@ -24,22 +26,22 @@ func TestNewOrderBook(t *testing.T) { } } -func TestSubmitWrongOrder(t *testing.T) { - ob := NewOrderBook("LINK") +// func TestSubmitWrongOrder(t *testing.T) { +// ob := NewOrderBook("LINK") - new_bad_order := order.NewOrder(-10, 100, types.Sell, types.Limit, "LINK") - ob.Submit(new_bad_order) +// new_bad_order := order.NewOrder(-10, 100, db.OrderSideTypeSELL, db.OrderTypeLIMIT, "LINK", uuid.New()) +// ob.Submit(new_bad_order, nil) - if ob.Asks.Len() > 0 { - t.Errorf("Expected order not to be submitted to the orderbook.") - } -} +// if ob.Asks.Len() > 0 { +// t.Errorf("Expected order not to be submitted to the orderbook.") +// } +// } func TestSubmitLimitSellOrder(t *testing.T) { ob := NewOrderBook("LINK") - test_order := order.NewOrder(10, 100, types.Sell, types.Limit, "LINK") + test_order := order.NewOrder(60, 100, db.OrderSideTypeBUY, db.OrderTypeLIMIT, "LINK", uuid.New()) - ob.Submit(test_order) + ob.Submit(test_order, nil) if ob.Bids.Len() > 0 { t.Error("SELL order was transmitted as a BUY order.") } @@ -51,9 +53,9 @@ func TestSubmitLimitSellOrder(t *testing.T) { func TestSubmitLimitBuyOrder(t *testing.T) { ob := NewOrderBook("LINK") - test_order := order.NewOrder(10, 100, types.Buy, types.Limit, "LINK") + test_order := order.NewOrder(60, 100, db.OrderSideTypeBUY, db.OrderTypeLIMIT, "LINK", uuid.New()) - ob.Submit(test_order) + ob.Submit(test_order, nil) if ob.Asks.Len() > 0 { t.Error("BUY order was transmitted as a SELL order.") } @@ -65,11 +67,11 @@ func TestSubmitLimitBuyOrder(t *testing.T) { func TestSubmitMarketBuyOrder(t *testing.T) { ob := NewOrderBook("LINK") - test_order := order.NewOrder(10, 100, types.Buy, types.Limit, "LINK") - test_order_sell := order.NewOrder(10, 100, types.Sell, types.Market, "LINK") + test_order := order.NewOrder(10, 100, db.OrderSideTypeBUY, db.OrderTypeLIMIT, "LINK", uuid.New()) + test_order_sell := order.NewOrder(10, 100, db.OrderSideTypeSELL, db.OrderTypeMARKET, "LINK", uuid.New()) - ob.Submit(test_order) - ob.Submit(test_order_sell) + ob.Submit(test_order, nil) + ob.Submit(test_order_sell, nil) if ob.Asks.Len() > 0 { t.Error("BUY order was transmitted as a SELL order.") @@ -86,9 +88,9 @@ func TestSubmitMarketBuyOrder(t *testing.T) { func TestSubmitMarketSellOrder(t *testing.T) { ob := NewOrderBook("LINK") - test_order_sell := order.NewOrder(10, 100, types.Sell, types.Limit, "LINK") + test_order_sell := order.NewOrder(10, 100, db.OrderSideTypeSELL, db.OrderTypeLIMIT, "LINK", uuid.New()) - ob.Submit(test_order_sell) + ob.Submit(test_order_sell, nil) if ob.Bids.Len() > 0 { t.Error("BUY order was transmitted as a SELL order.") @@ -101,9 +103,9 @@ func TestSubmitMarketSellOrder(t *testing.T) { func TestMarketOrderNoLiquidity(t *testing.T) { ob := NewOrderBook("LINK") - test_order := order.NewOrder(10, 100, types.Buy, types.Market, "LINK") + test_order := order.NewOrder(10, 100, db.OrderSideTypeBUY, db.OrderTypeMARKET, "LINK", uuid.New()) - result := ob.Submit(test_order) + result := ob.Submit(test_order, nil) if result == true { t.Error("Market order was processed despite insufficient liquidity.") } @@ -111,13 +113,13 @@ func TestMarketOrderNoLiquidity(t *testing.T) { func TestWithdrawOrder(t *testing.T) { ob := NewOrderBook("LINK") - test_order := order.NewOrder(12, 100, types.Buy, types.Limit, "LINK") - buy_order_to_withdraw := order.NewOrder(60, 100, types.Buy, types.Limit, "LINK") - sell_order_to_withdraw := order.NewOrder(120, 100, types.Sell, types.Limit, "LINK") + test_order := order.NewOrder(12, 100, db.OrderSideTypeBUY, db.OrderTypeLIMIT, "LINK", uuid.New()) + buy_order_to_withdraw := order.NewOrder(60, 100, db.OrderSideTypeBUY, db.OrderTypeLIMIT, "LINK", uuid.New()) + sell_order_to_withdraw := order.NewOrder(120, 100, db.OrderSideTypeSELL, db.OrderTypeLIMIT, "LINK", uuid.New()) - ob.Submit(test_order) - ob.Submit(buy_order_to_withdraw) - ob.Submit(sell_order_to_withdraw) + ob.Submit(test_order, nil) + ob.Submit(buy_order_to_withdraw, nil) + ob.Submit(sell_order_to_withdraw, nil) if ob.Bids.Len()+ob.Asks.Len() != 3 { t.Error("Not all orders were submitted.") @@ -137,7 +139,7 @@ func TestWithdrawOrder(t *testing.T) { func TestWithdrawEmptyOrderBook(t *testing.T) { ob := NewOrderBook("LINK") - order_to_withdraw := order.NewOrder(60, 100, types.Buy, types.Limit, "LINK") + order_to_withdraw := order.NewOrder(60, 100, db.OrderSideTypeBUY, db.OrderTypeLIMIT, "LINK", uuid.New()) result := ob.Withdraw(order_to_withdraw) @@ -149,7 +151,7 @@ func TestWithdrawEmptyOrderBook(t *testing.T) { func TestWithdrawBadOrder(t *testing.T) { ob := NewOrderBook("LINK") - order_to_withdraw := order.NewOrder(60, 100, "SIDEWAYS", types.Limit, "LINK") + order_to_withdraw := order.NewOrder(60, 100, "SIDEWAYS", db.OrderTypeLIMIT, "LINK", uuid.New()) result := ob.Withdraw(order_to_withdraw) @@ -163,11 +165,11 @@ func TestOrdersMatched(t *testing.T) { ob := NewOrderBook("LINK") // Case 1: similar orders - order_1 := order.NewOrder(60, 100, types.Buy, types.Limit, "LINK") - order_2 := order.NewOrder(60, 100, types.Sell, types.Limit, "LINK") + order_1 := order.NewOrder(60, 100, db.OrderSideTypeBUY, db.OrderTypeLIMIT, "LINK", uuid.New()) + order_2 := order.NewOrder(60, 100, db.OrderSideTypeSELL, db.OrderTypeLIMIT, "LINK", uuid.New()) - ob.Submit(order_1) - ob.Submit(order_2) + ob.Submit(order_1, nil) + ob.Submit(order_2, nil) if ob.Asks.Len()-ob.Bids.Len() != 0 { t.Errorf("Expected an empty order book.") @@ -178,27 +180,27 @@ func TestOrdersMatched(t *testing.T) { func TestOrdersPartiallyMatched(t *testing.T) { ob := NewOrderBook("LINK") - order_1 := order.NewOrder(60, 100, types.Buy, types.Limit, "LINK") - order_2 := order.NewOrder(60, 50, types.Sell, types.Limit, "LINK") - order_3 := order.NewOrder(60, 165, types.Sell, types.Market, "LINK") - order_4 := order.NewOrder(62, 150, types.Buy, types.Limit, "LINK") + order_1 := order.NewOrder(60, 100, db.OrderSideTypeBUY, db.OrderTypeLIMIT, "LINK", uuid.New()) + order_2 := order.NewOrder(60, 50, db.OrderSideTypeSELL, db.OrderTypeLIMIT, "LINK", uuid.New()) + order_3 := order.NewOrder(60, 165, db.OrderSideTypeSELL, db.OrderTypeMARKET, "LINK", uuid.New()) + order_4 := order.NewOrder(62, 150, db.OrderSideTypeBUY, db.OrderTypeLIMIT, "LINK", uuid.New()) - ob.Submit(order_1) - ob.Submit(order_2) + ob.Submit(order_1, nil) + ob.Submit(order_2, nil) - if order_1.Status != types.PartiallyFilled { + if order_1.Status != db.OrderStatusTypePARTIALLYFILLED { t.Errorf("Expected status to be partially filled.") } - ob.Submit(order_3) + ob.Submit(order_3, nil) - if order_3.Status != types.PartiallyFilled { + if order_3.Status != db.OrderStatusTypePARTIALLYFILLED { t.Errorf("Expected status to be partially filled.") } - ob.Submit(order_4) + ob.Submit(order_4, nil) - if order_3.Status != types.Filled { + if order_3.Status != db.OrderStatusTypeFILLED { t.Errorf("Expected status to be fully filled.") } @@ -214,7 +216,7 @@ func TestOrdersPartiallyMatched(t *testing.T) { t.Errorf("Expected no remaining ask orders left.") } - if order_4.Status != types.PartiallyFilled { + if order_4.Status != db.OrderStatusTypePARTIALLYFILLED { t.Errorf("Expected status to be partially filled.") } @@ -222,9 +224,9 @@ func TestOrdersPartiallyMatched(t *testing.T) { func TestWrongOrderBook(t *testing.T) { ob := NewOrderBook("XRP") - order_1 := order.NewOrder(60, 100, types.Buy, types.Limit, "LINK") + order_1 := order.NewOrder(60, 100, db.OrderSideTypeBUY, db.OrderTypeLIMIT, "LINK", uuid.New()) - result := ob.Submit(order_1) + result := ob.Submit(order_1, nil) if result { t.Errorf("Order was submitted to the wrong orderbook.") } diff --git a/main.go b/main.go index 3147698..f320016 100644 --- a/main.go +++ b/main.go @@ -1,24 +1,72 @@ package main import ( + "context" + "database/sql" + "log" + "net/http" + + _ "github.com/lib/pq" + + "exchange/internal/api/handler" + "exchange/internal/api/router" + "exchange/internal/db" "exchange/internal/order" "exchange/internal/orderbook" - "exchange/internal/types" ) func main() { - testOrderbook := orderbook.NewOrderBook("LINK") + // @TODO: add connection string to .env + connStr := "postgresql://postgres:postgres@localhost:5432/testdb?sslmode=disable" + + database, err := sql.Open("postgres", connStr) + if err != nil { + log.Fatalf("Could not connect to DB: %v", err) + } + queries := db.New(database) + + // Load available assets on database + + assets, err := queries.GetAllAssets(context.Background()) + if err != nil { + log.Fatal("Could not load tradable assets:", err) + } + + validTickers := make(map[string]struct{}) + for _, asset := range assets { + if asset.IsTradable { + validTickers[asset.Ticker] = struct{}{} + } + } + + // Build orderbooks based on available assets + orderbooks := make(map[string]*orderbook.OrderBook) + for ticker := range validTickers { + orderbooks[ticker] = orderbook.NewOrderBook(ticker) + } + + // Set up API handler and router + h := &handler.Handler{Queries: queries, OrderBooks: orderbooks, ValidTickers: validTickers} + r := router.NewRouter(h) - new_order := order.NewOrder(12, 100, types.Buy, types.Market, "LINK") - testOrderbook.Submit(new_order) - testOrderbook.Submit(new_order) + // Get all submitted orders not filled and add them to the orderbook + orders, err := queries.GetSubmittedOrders(context.Background()) + if err != nil { + panic(err) + } - new_order = order.NewOrder(13, 100, types.Sell, types.Market, "LINK") - testOrderbook.Submit(new_order) + for _, o := range orders { + persisted_order := order.NewOrder(o.Price, o.Amount, o.Side, o.OrderType, o.Asset, o.CreatedBy) + persisted_order.ID = o.ID + if o.Side == "BUY" { + h.OrderBooks[o.Asset].Bids.Insert(persisted_order) + } else { + h.OrderBooks[o.Asset].Asks.Insert(persisted_order) + } + } - new_order = order.NewOrder(14, 100, types.Buy, types.Market, "LINK") - testOrderbook.Submit(new_order) + // TX processing workers + go handler.StartTransactionPersistenceWorker(queries) - new_order = order.NewOrder(12, 50, types.Sell, types.Market, "LINK") - testOrderbook.Submit(new_order) + http.ListenAndServe(":3000", r) } diff --git a/sqlc.yaml b/sqlc.yaml new file mode 100644 index 0000000..f65e0c3 --- /dev/null +++ b/sqlc.yaml @@ -0,0 +1,16 @@ +version: "2" +sql: + - engine: "postgresql" + queries: "internal/db/query/" + schema: "internal/db/migration/" + gen: + go: + package: "db" + out: "internal/db" + emit_json_tags: true + emit_interface: false + emit_empty_slices: true + emit_exported_queries: true + overrides: + - db_type: "uuid" + go_type: "github.com/google/uuid.UUID"