diff --git a/.gitignore b/.gitignore index 737bd40..0965051 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,7 @@ build # Ignore TODO TODO buffer + +migrations/data.up.sql + +*.diff diff --git a/go.mod b/go.mod index 2b15c11..d5ccc3a 100644 --- a/go.mod +++ b/go.mod @@ -11,11 +11,16 @@ require ( github.com/joho/godotenv v1.5.1 github.com/labstack/echo/v4 v4.13.3 github.com/redis/go-redis/v9 v9.12.1 + github.com/stretchr/testify v1.10.0 + github.com/tonkeeper/tongo v1.16.54 + github.com/xssnick/tonutils-go v1.15.5 ) require ( + filippo.io/edwards25519 v1.1.0 // indirect github.com/BurntSushi/toml v1.2.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect @@ -24,14 +29,18 @@ require ( github.com/labstack/gommon v0.4.2 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect + github.com/oasisprotocol/curve25519-voi v0.0.0-20220328075252-7dd334e3daae // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rogpeppe/go-internal v1.14.1 // indirect + github.com/snksoft/crc v1.1.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasttemplate v1.2.2 // indirect - golang.org/x/crypto v0.37.0 // indirect - golang.org/x/net v0.38.0 // indirect - golang.org/x/sync v0.13.0 // indirect - golang.org/x/sys v0.32.0 // indirect - golang.org/x/text v0.24.0 // indirect + golang.org/x/crypto v0.45.0 // indirect + golang.org/x/exp v0.0.0-20230116083435-1de6713980de // indirect + golang.org/x/net v0.47.0 // indirect + golang.org/x/sync v0.18.0 // indirect + golang.org/x/sys v0.38.0 // indirect + golang.org/x/text v0.31.0 // indirect golang.org/x/time v0.8.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect olympos.io/encoding/edn v0.0.0-20201019073823-d3554ca0b0a3 // indirect diff --git a/go.sum b/go.sum index 7a7ae01..89b7caa 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= +filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= github.com/BurntSushi/toml v1.2.1 h1:9F2/+DoOYIOksmaJFPw1tGFy1eDnIJXg+UHjuD8lTak= github.com/BurntSushi/toml v1.2.1/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= @@ -43,33 +45,43 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/oasisprotocol/curve25519-voi v0.0.0-20220328075252-7dd334e3daae h1:7smdlrfdcZic4VfsGKD2ulWL804a4GVphr4s7WZxGiY= +github.com/oasisprotocol/curve25519-voi v0.0.0-20220328075252-7dd334e3daae/go.mod h1:hVoHR2EVESiICEMbg137etN/Lx+lSrHPTD39Z/uE+2s= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/redis/go-redis/v9 v9.12.1 h1:k5iquqv27aBtnTm2tIkROUDp8JBXhXZIVu1InSgvovg= github.com/redis/go-redis/v9 v9.12.1/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw= github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= +github.com/snksoft/crc v1.1.0 h1:HkLdI4taFlgGGG1KvsWMpz78PkOC9TkPVpTV/cuWn48= +github.com/snksoft/crc v1.1.0/go.mod h1:5/gUOsgAm7OmIhb6WJzw7w5g2zfJi4FrHYgGPdshE+A= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tonkeeper/tongo v1.16.54 h1:X8+VnC/gR/0+S1jlcY1hmvkL0wUrn5TEOyWL6uTHmd8= +github.com/tonkeeper/tongo v1.16.54/go.mod h1:MjgIgAytFarjCoVjMLjYEtpZNN1f2G/pnZhKjr28cWs= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasttemplate v1.2.2 h1:lxLXG0uE3Qnshl9QyaK6XJxMXlQZELvChBOCmQD0Loo= github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= -golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE= -golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc= -golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= -golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= -golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610= -golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +github.com/xssnick/tonutils-go v1.15.5 h1:yAcHnDaY5QW0aIQE47lT0PuDhhHYE+N+NyZssdPKR0s= +github.com/xssnick/tonutils-go v1.15.5/go.mod h1:3/B8mS5IWLTd1xbGbFbzRem55oz/Q86HG884bVsTqZ8= +golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= +golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= +golang.org/x/exp v0.0.0-20230116083435-1de6713980de h1:DBWn//IJw30uYCgERoxCg84hWtA97F4wMiKOIh00Uf0= +golang.org/x/exp v0.0.0-20230116083435-1de6713980de/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= +golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= +golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= +golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I= +golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20= -golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= -golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0= -golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU= +golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= +golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM= +golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM= golang.org/x/time v0.8.0 h1:9i3RxcPv3PZnitoVGMPDKZSq1xW1gK1Xy3ArNOGZfEg= golang.org/x/time v0.8.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/internal/app/distributor/distributor.go b/internal/app/distributor/distributor.go new file mode 100644 index 0000000..e1aa0da --- /dev/null +++ b/internal/app/distributor/distributor.go @@ -0,0 +1,101 @@ +package distributor + +import ( + "context" + "errors" + "fmt" + "log/slog" + "math/big" + + "github.com/jackc/pgx/v5" + "github.com/voidcontests/api/internal/lib/crypto" + "github.com/voidcontests/api/internal/storage/models" + "github.com/voidcontests/api/internal/storage/repository" + "github.com/voidcontests/api/pkg/ton" + "github.com/xssnick/tonutils-go/address" + "github.com/xssnick/tonutils-go/tlb" +) + +func New(r *repository.Repository, tc *ton.Client, cipher crypto.Cipher) func(ctx context.Context) error { + return func(ctx context.Context) error { + contests, err := r.Contest.GetWithUndistributedAwards(ctx) + if err != nil { + return err + } + + for _, c := range contests { + err := distributeAwardForContest(ctx, r, tc, cipher, c) + if err != nil { + return err + } + } + return nil + } +} + +func distributeAwardForContest(ctx context.Context, r *repository.Repository, tc *ton.Client, cipher crypto.Cipher, c models.Contest) error { + w, err := r.Contest.GetWallet(ctx, *c.WalletID) + if err != nil { + return err + } + + decryptedMnemonic, err := cipher.Decrypt(w.MnemonicEncrypted) + if err != nil { + return fmt.Errorf("failed to decrypt mnemonic: %w", err) + } + + wallet, err := tc.WalletWithSeed(decryptedMnemonic) + if err != nil { + return err + } + + winnerID, err := r.Contest.GetWinnerID(ctx, c.ID) + if errors.Is(err, pgx.ErrNoRows) { + slog.Info("no users that submitted at least one ok solution", slog.Int("contest_id", c.ID)) + return nil + } + if err != nil { + return err + } + + winner, err := r.User.GetByID(ctx, winnerID) + if err != nil { + return err + } + + if winner.Address == "" { + return fmt.Errorf("user has no wallet") + } + + recepient, err := address.ParseAddr(winner.Address) + if err != nil { + return err + } + + nanos, err := tc.GetBalance(ctx, wallet.Address()) + if err != nil { + return err + } + + // keep 2% for paying gas + factor := 1 - 0.02 + amount := tlb.FromNanoTON(big.NewInt(int64(float64(nanos) * factor))) + tx, err := wallet.TransferTo(ctx, recepient, amount, fmt.Sprintf("contests.fckn.engineer: Prize for winning contest #%d", c.ID)) + if err != nil { + return err + } + + slog.Info("award distributed", slog.Int("contest_id", c.ID), slog.String("tx", tx)) + + paymentID, err := r.Payment.Create(ctx, tx, tc.GetAddress(wallet.Address()), tc.GetAddress(recepient), amount.Nano().Uint64(), false) + if err != nil { + return err + } + + err = r.Contest.SetDistributionPaymentID(ctx, c.ID, paymentID) + if err != nil { + return err + } + + return nil +} diff --git a/internal/app/handler/account.go b/internal/app/handler/account.go index 7978b9c..561c68b 100644 --- a/internal/app/handler/account.go +++ b/internal/app/handler/account.go @@ -2,25 +2,22 @@ package handler import ( "errors" - "fmt" "log/slog" "net/http" "strings" jwtgo "github.com/golang-jwt/jwt/v4" - "github.com/jackc/pgx/v5" "github.com/labstack/echo/v4" "github.com/voidcontests/api/internal/app/handler/dto/request" "github.com/voidcontests/api/internal/app/handler/dto/response" - "github.com/voidcontests/api/internal/hasher" + "github.com/voidcontests/api/internal/app/service" "github.com/voidcontests/api/internal/jwt" - "github.com/voidcontests/api/internal/lib/logger/sl" + "github.com/voidcontests/api/internal/storage/models" "github.com/voidcontests/api/pkg/requestid" "github.com/voidcontests/api/pkg/validate" ) func (h *Handler) CreateAccount(c echo.Context) error { - op := "handler.CreateAccount" ctx := c.Request().Context() var body request.CreateAccount @@ -28,28 +25,20 @@ func (h *Handler) CreateAccount(c echo.Context) error { return Error(http.StatusBadRequest, "invalid body: missing required fields") } - exists, err := h.repo.User.Exists(ctx, body.Username) + id, err := h.service.Account.CreateAccount(ctx, body.Username, body.Password) if err != nil { - return fmt.Errorf("%s: can't verify that user exists or not: %v", op, err) - } - - if exists { - return Error(http.StatusConflict, "user already exists") - } - - passwordHash := hasher.Sha256String([]byte(body.Password), []byte(h.config.Security.Salt)) - user, err := h.repo.User.Create(ctx, body.Username, passwordHash) - if err != nil { - return fmt.Errorf("%s: failed to create user: %v", op, err) + if errors.Is(err, service.ErrUserAlreadyExists) { + return Error(http.StatusConflict, "user already exists") + } + return err } return c.JSON(http.StatusCreated, response.ID{ - ID: user.ID, + ID: id, }) } func (h *Handler) CreateSession(c echo.Context) error { - op := "handler.CreateSession" ctx := c.Request().Context() var body request.CreateSession @@ -57,18 +46,12 @@ func (h *Handler) CreateSession(c echo.Context) error { return Error(http.StatusBadRequest, "invalid body: missing required fields") } - passwordHash := hasher.Sha256String([]byte(body.Password), []byte(h.config.Security.Salt)) - user, err := h.repo.User.GetByCredentials(ctx, body.Username, passwordHash) - if errors.Is(err, pgx.ErrNoRows) { - return Error(http.StatusUnauthorized, "user not found") - } - if err != nil { - return fmt.Errorf("%s: can't create user: %v", op, err) - } - - token, err := jwt.GenerateToken(user.ID, h.config.Security.SignatureKey) + token, err := h.service.Account.CreateSession(ctx, body.Username, body.Password) if err != nil { - return fmt.Errorf("%s: can't generate token: %v", op, err) + if errors.Is(err, service.ErrInvalidCredentials) { + return Error(http.StatusUnauthorized, "invalid credentials") + } + return err } return c.JSON(http.StatusCreated, response.Token{ @@ -77,32 +60,64 @@ func (h *Handler) CreateSession(c echo.Context) error { } func (h *Handler) GetAccount(c echo.Context) error { - op := "handler.GetAccount" ctx := c.Request().Context() claims, _ := ExtractClaims(c) - user, err := h.repo.User.GetByID(ctx, claims.UserID) - if errors.Is(err, pgx.ErrNoRows) { - return Error(http.StatusUnauthorized, "invalid or expired token") - } + accountInfo, err := h.service.Account.GetAccount(ctx, claims.UserID) if err != nil { - return fmt.Errorf("%s: can't get user: %v", op, err) + if errors.Is(err, service.ErrInvalidToken) { + return Error(http.StatusUnauthorized, "invalid or expired token") + } + return err + } + + return c.JSON(http.StatusOK, response.Account{ + ID: accountInfo.User.ID, + Username: accountInfo.User.Username, + Address: accountInfo.User.Address, + Role: response.Role{ + Name: accountInfo.Role.Name, + CreatedProblemsLimit: accountInfo.Role.CreatedProblemsLimit, + CreatedContestsLimit: accountInfo.Role.CreatedContestsLimit, + }, + }) +} + +func (h *Handler) UpdateAccount(c echo.Context) error { + ctx := c.Request().Context() + + claims, _ := ExtractClaims(c) + + var body request.UpdateAccount + if err := validate.Bind(c, &body); err != nil { + return Error(http.StatusBadRequest, "invalid body: missing required fields") } - role, err := h.repo.User.GetRole(ctx, claims.UserID) + if body.Username == nil && body.Address == nil { + return Error(http.StatusBadRequest, "at least one field must be provided") + } + + params := models.UpdateUserParams{ + Username: body.Username, + Address: body.Address, + } + + user, err := h.service.Account.UpdateAccount(ctx, claims.UserID, params) if err != nil { - return fmt.Errorf("%s: can't get role: %v", op, err) + if errors.Is(err, service.ErrUserAlreadyExists) { + return Error(http.StatusConflict, "username already taken") + } + if errors.Is(err, service.ErrInvalidToken) { + return Error(http.StatusUnauthorized, "invalid or expired token") + } + return err } - return c.JSON(http.StatusOK, response.Account{ + return c.JSON(http.StatusOK, response.User{ ID: user.ID, Username: user.Username, - Role: response.Role{ - Name: role.Name, - CreatedProblemsLimit: role.CreatedProblemsLimit, - CreatedContestsLimit: role.CreatedContestsLimit, - }, + Address: user.Address, }) } @@ -121,7 +136,6 @@ func (h *Handler) UserIdentity(skiperr bool) echo.MiddlewareFunc { authHeader := c.Request().Header.Get(echo.HeaderAuthorization) if authHeader == "" { - log.Debug("auth header is empty, skipping check") if skiperr { return next(c) } else { @@ -149,7 +163,6 @@ func (h *Handler) UserIdentity(skiperr bool) echo.MiddlewareFunc { }) if err != nil { - log.Debug("token parsing failed", sl.Err(err)) if skiperr { return next(c) } else { @@ -158,7 +171,6 @@ func (h *Handler) UserIdentity(skiperr bool) echo.MiddlewareFunc { } if !token.Valid { - log.Debug("invalid token") if skiperr { return next(c) } else { @@ -168,7 +180,6 @@ func (h *Handler) UserIdentity(skiperr bool) echo.MiddlewareFunc { claims, ok := token.Claims.(*jwt.CustomClaims) if !ok { - log.Debug("invalid token claims") if skiperr { return next(c) } else { diff --git a/internal/app/handler/contest.go b/internal/app/handler/contest.go index aadfcbc..fe852db 100644 --- a/internal/app/handler/contest.go +++ b/internal/app/handler/contest.go @@ -2,61 +2,58 @@ package handler import ( "errors" - "fmt" "net/http" - "time" - "github.com/jackc/pgx/v5" "github.com/labstack/echo/v4" "github.com/voidcontests/api/internal/app/handler/dto/request" "github.com/voidcontests/api/internal/app/handler/dto/response" + "github.com/voidcontests/api/internal/app/service" "github.com/voidcontests/api/internal/storage/models" "github.com/voidcontests/api/pkg/validate" ) func (h *Handler) CreateContest(c echo.Context) error { - op := "handler.CreateContest" ctx := c.Request().Context() claims, _ := ExtractClaims(c) - var body request.CreateContestRequest + var body request.CreateContest if err := validate.Bind(c, &body); err != nil { return Error(http.StatusBadRequest, "invalid body: missing required fields") } - userrole, err := h.repo.User.GetRole(ctx, claims.UserID) + id, err := h.service.Contest.CreateContest(ctx, service.CreateContestParams{ + UserID: claims.UserID, + Title: body.Title, + Description: body.Description, + AwardType: body.AwardType, + EntryPriceTonNanos: body.EntryPriceTonNanos, + StartTime: body.StartTime, + EndTime: body.EndTime, + DurationMins: body.DurationMins, + MaxEntries: body.MaxEntries, + AllowLateJoin: body.AllowLateJoin, + ProblemIDs: body.ProblemsIDs, + }) if err != nil { - return fmt.Errorf("%s: can't get role: %v", op, err) - } - - if userrole.Name == models.RoleBanned { - return Error(http.StatusForbidden, "you are banned from creating contests") - } - - if userrole.Name == models.RoleLimited { - cscount, err := h.repo.User.GetCreatedContestsCount(ctx, claims.UserID) - if err != nil { - return fmt.Errorf("%s: can't get created contests count: %v", op, err) - } - - if cscount >= int(userrole.CreatedContestsLimit) { + switch { + case errors.Is(err, service.ErrUserBanned): + return Error(http.StatusForbidden, "you are banned from creating contests") + case errors.Is(err, service.ErrContestsLimitExceeded): return Error(http.StatusForbidden, "contests limit exceeded") + case errors.Is(err, service.ErrInvalidContestTiming): + return Error(http.StatusBadRequest, "invalid contest timing: check start time, end time, and duration") + default: + return err } } - contestID, err := h.repo.Contest.CreateWithProblemIDs(ctx, claims.UserID, body.Title, body.Description, body.StartTime, body.EndTime, body.DurationMins, body.MaxEntries, body.AllowLateJoin, body.ProblemsIDs) - if err != nil { - return fmt.Errorf("%s: can't create contest: %v", op, err) - } - return c.JSON(http.StatusCreated, response.ID{ - ID: contestID, + ID: id, }) } func (h *Handler) GetContestByID(c echo.Context) error { - op := "handler.GetContestByID" ctx := c.Request().Context() claims, authenticated := ExtractClaims(c) @@ -66,141 +63,143 @@ func (h *Handler) GetContestByID(c echo.Context) error { return Error(http.StatusBadRequest, "contest ID should be an integer") } - contest, err := h.repo.Contest.GetByID(ctx, int32(contestID)) - if errors.Is(err, pgx.ErrNoRows) { - return Error(http.StatusNotFound, "contest not found") - } + details, err := h.service.Contest.GetContestByID(ctx, contestID, claims.UserID, authenticated) if err != nil { - return fmt.Errorf("%s: can't get contest: %v", op, err) - } - - if contest.EndTime.Before(time.Now()) { - if (authenticated && claims.UserID != contest.CreatorID) || !authenticated { + switch { + case errors.Is(err, service.ErrContestNotFound): return Error(http.StatusNotFound, "contest not found") + case errors.Is(err, service.ErrContestFinished): + return Error(http.StatusNotFound, "contest not found") + default: + return err } } - problems, err := h.repo.Contest.GetProblemset(ctx, contest.ID) - if err != nil { - return fmt.Errorf("%s: can't get problemset: %v", op, err) + contest := details.Contest + n := len(details.Problems) + awards := response.Awards{ + Kind: contest.AwardType, + Nanocoins: details.PrizeNanosTON, + IsDistributed: contest.DistributionPaymentID != nil, + } + if details.DistributionPayment != nil { + awards.DistributionTxHash = details.DistributionPayment.TxHash } - n := len(problems) cdetailed := response.ContestDetailed{ ID: contest.ID, Title: contest.Title, Description: contest.Description, - Problems: make([]response.ContestProblemListItem, n, n), Creator: response.User{ ID: contest.CreatorID, Username: contest.CreatorUsername, + Address: contest.CreatorAddress, }, - Participants: contest.Participants, - StartTime: contest.StartTime, - EndTime: contest.EndTime, - DurationMins: contest.DurationMins, - MaxEntries: contest.MaxEntries, - AllowLateJoin: contest.AllowLateJoin, - CreatedAt: contest.CreatedAt, + Address: details.WalletAddress, + StartTime: contest.StartTime, + EndTime: contest.EndTime, + DurationMins: contest.DurationMins, + Participants: contest.ParticipantsCount, + MaxEntries: contest.MaxEntries, + IsRegistrationOpen: details.IsRegistrationOpen, + EntryPriceTonNanos: contest.EntryPriceTonNanos, + Awards: awards, + Problems: make([]response.ContestProblemListItem, n, n), + CreatedAt: contest.CreatedAt, + } + + if details.EntryDetails != nil { + entryDetails := details.EntryDetails + contestEntry := response.Entry{ + IsAdmitted: entryDetails.IsAdmitted, + SubmissionDeadline: entryDetails.SubmissionDeadline, + Message: entryDetails.Message, + IsPaid: entryDetails.Entry.PaymentID != nil, + CreatedAt: entryDetails.Entry.CreatedAt, + } + + if entryDetails.Payment != nil { + contestEntry.Payment = &response.PaymentDetails{ + TxHash: entryDetails.Payment.TxHash, + CreatedAt: entryDetails.Payment.CreatedAt, + } + } + + cdetailed.Entry = &contestEntry } for i := range n { + p := details.Problems[i] cdetailed.Problems[i] = response.ContestProblemListItem{ - ID: problems[i].ID, - Charcode: problems[i].Charcode, + ID: p.ID, + Charcode: p.Charcode, Writer: response.User{ - ID: problems[i].WriterID, - Username: problems[i].WriterUsername, + ID: p.WriterID, + Username: p.WriterUsername, + Address: p.WriterAddress, }, - Title: problems[i].Title, - Difficulty: problems[i].Difficulty, - TimeLimitMS: problems[i].TimeLimitMS, - MemoryLimitMB: problems[i].MemoryLimitMB, - Checker: problems[i].Checker, - CreatedAt: problems[i].CreatedAt, + Title: p.Title, + Difficulty: p.Difficulty, + TimeLimitMS: p.TimeLimitMS, + MemoryLimitMB: p.MemoryLimitMB, + Checker: p.Checker, + CreatedAt: p.CreatedAt, + Status: details.ProblemStatuses[p.ID], } } - // NOTE: Return contest without problem submissions - // statuses if user is not authenticated - if !authenticated { - return c.JSON(http.StatusOK, cdetailed) - } - - entry, err := h.repo.Entry.Get(ctx, contest.ID, claims.UserID) - if err != nil && !errors.Is(err, pgx.ErrNoRows) { - return fmt.Errorf("%s: can't get entry: %v", op, err) - } - if errors.Is(err, pgx.ErrNoRows) { - return c.JSON(http.StatusOK, cdetailed) - } - - cdetailed.IsParticipant = true - - _, deadline := AllowSubmitAt(contest, entry) - if contest.StartTime.Before(time.Now()) { - cdetailed.SubmissionDeadline = &deadline - } - - statuses, err := h.repo.Submission.GetProblemStatuses(ctx, entry.ID) - if err != nil { - return fmt.Errorf("%s: can't get submissions: %v", op, err) - } - - for i := range n { - problemID := cdetailed.Problems[i].ID - cdetailed.Problems[i].Status = statuses[problemID] - } - return c.JSON(http.StatusOK, cdetailed) } func (h *Handler) GetCreatedContests(c echo.Context) error { - op := "handler.GetCreatedContests" ctx := c.Request().Context() claims, _ := ExtractClaims(c) limit, ok := ExtractQueryParamInt(c, "limit") - if !ok { + if !ok || limit < 0 { limit = 10 } offset, ok := ExtractQueryParamInt(c, "offset") - if !ok { + if !ok || offset < 0 { offset = 0 } - contests, total, err := h.repo.Contest.GetWithCreatorID(ctx, claims.UserID, limit, offset) + result, err := h.service.Contest.ListCreatedContests(ctx, claims.UserID, limit, offset) if err != nil { - return fmt.Errorf("%s: can't get created contests: %v", op, err) + return err } items := make([]response.ContestListItem, 0) - for _, contest := range contests { + for _, contest := range result.Contests { item := response.ContestListItem{ ID: contest.ID, Creator: response.User{ ID: contest.CreatorID, Username: contest.CreatorUsername, + Address: contest.CreatorAddress, }, - Title: contest.Title, - StartTime: contest.StartTime, - EndTime: contest.EndTime, - DurationMins: contest.DurationMins, - MaxEntries: contest.MaxEntries, - Participants: contest.Participants, - CreatedAt: contest.CreatedAt, + Title: contest.Title, + AwardType: contest.AwardType, + EntryPriceTonNanos: contest.EntryPriceTonNanos, + StartTime: contest.StartTime, + EndTime: contest.EndTime, + DurationMins: contest.DurationMins, + MaxEntries: contest.MaxEntries, + Participants: contest.ParticipantsCount, + AwardDistributed: contest.DistributionPaymentID != nil, + CreatedAt: contest.CreatedAt, } items = append(items, item) } return c.JSON(http.StatusOK, response.Pagination[response.ContestListItem]{ Meta: response.Meta{ - Total: total, + Total: result.Total, Limit: limit, Offset: offset, - HasNext: offset+limit < total, + HasNext: offset+limit < result.Total, HasPrev: offset > 0, }, Items: items, @@ -208,7 +207,6 @@ func (h *Handler) GetCreatedContests(c echo.Context) error { } func (h *Handler) GetContests(c echo.Context) error { - op := "handler.GetContests" ctx := c.Request().Context() limit, ok := ExtractQueryParamInt(c, "limit") @@ -221,44 +219,60 @@ func (h *Handler) GetContests(c echo.Context) error { offset = 0 } - contests, total, err := h.repo.Contest.ListAll(ctx, limit, offset) + filters := models.ContestFilters{} + + if creatorID, ok := ExtractQueryParamInt(c, "creator_id"); ok { + if creatorID <= 0 { + return Error(http.StatusBadRequest, "creator_id should be a valid integer, greater than 0") + } + filters.CreatorID = creatorID + } + + if title := c.QueryParam("title"); title != "" { + filters.Title = title + } + + result, err := h.service.Contest.ListAllContests(ctx, limit, offset, filters) if err != nil { - return fmt.Errorf("%s: can't get contests: %w", op, err) + return err } items := make([]response.ContestListItem, 0) - for _, contest := range contests { + for _, contest := range result.Contests { item := response.ContestListItem{ ID: contest.ID, Creator: response.User{ ID: contest.CreatorID, Username: contest.CreatorUsername, + Address: contest.CreatorAddress, }, - Title: contest.Title, - StartTime: contest.StartTime, - EndTime: contest.EndTime, - DurationMins: contest.DurationMins, - MaxEntries: contest.MaxEntries, - Participants: contest.Participants, - CreatedAt: contest.CreatedAt, + Title: contest.Title, + AwardType: contest.AwardType, + EntryPriceTonNanos: contest.EntryPriceTonNanos, + StartTime: contest.StartTime, + EndTime: contest.EndTime, + DurationMins: contest.DurationMins, + MaxEntries: contest.MaxEntries, + Participants: contest.ParticipantsCount, + AwardDistributed: contest.DistributionPaymentID != nil, + CreatedAt: contest.CreatedAt, } items = append(items, item) } return c.JSON(http.StatusOK, response.Pagination[response.ContestListItem]{ Meta: response.Meta{ - Total: total, + Total: result.Total, Limit: limit, Offset: offset, - HasNext: offset+limit < total, + HasNext: offset+limit < result.Total, HasPrev: offset > 0, }, Items: items, }) } -func (h *Handler) GetLeaderboard(c echo.Context) error { - op := "handler.GetLeaderboard" +func (h *Handler) GetScores(c echo.Context) error { ctx := c.Request().Context() contestID, ok := ExtractParamInt(c, "cid") @@ -267,36 +281,31 @@ func (h *Handler) GetLeaderboard(c echo.Context) error { } limit, ok := ExtractQueryParamInt(c, "limit") - if !ok { + if !ok || limit < 0 { limit = 50 } offset, ok := ExtractQueryParamInt(c, "offset") - if !ok { + if !ok || offset < 0 { offset = 0 } - _, err := h.repo.Contest.GetByID(ctx, int32(contestID)) - if errors.Is(err, pgx.ErrNoRows) { - return Error(http.StatusNotFound, "contest not found") - } - if err != nil { - return fmt.Errorf("%s: can't get contest: %v", op, err) - } - - leaderboard, total, err := h.repo.Contest.GetLeaderboard(ctx, contestID, limit, offset) + result, err := h.service.Contest.GetScores(ctx, contestID, limit, offset) if err != nil { - return fmt.Errorf("%s: can't get leaderboard: %v", op, err) + if errors.Is(err, service.ErrContestNotFound) { + return Error(http.StatusNotFound, "contest not found") + } + return err } - return c.JSON(http.StatusOK, response.Pagination[models.LeaderboardEntry]{ + return c.JSON(http.StatusOK, response.Pagination[models.ScoresEntry]{ Meta: response.Meta{ - Total: total, + Total: result.Total, Limit: limit, Offset: offset, - HasNext: offset+limit < total, + HasNext: offset+limit < result.Total, HasPrev: offset > 0, }, - Items: leaderboard, + Items: result.Scores, }) } diff --git a/internal/app/handler/dto/request/request.go b/internal/app/handler/dto/request/request.go index c0bbd7a..517eac0 100644 --- a/internal/app/handler/dto/request/request.go +++ b/internal/app/handler/dto/request/request.go @@ -16,18 +16,25 @@ type CreateSession struct { Password string `json:"password" required:"true"` } -type CreateContestRequest struct { - Title string `json:"title" required:"true"` - Description string `json:"description"` - ProblemsIDs []int32 `json:"problems_ids" required:"true"` - StartTime time.Time `json:"start_time" required:"true"` - EndTime time.Time `json:"end_time" required:"true"` - DurationMins int32 `json:"duration_mins" requried:"true"` - MaxEntries int32 `json:"max_entries"` - AllowLateJoin bool `json:"allow_late_join"` -} - -type CreateProblemRequest struct { +type UpdateAccount struct { + Username *string `json:"username"` + Address *string `json:"address"` +} + +type CreateContest struct { + Title string `json:"title" required:"true"` + Description string `json:"description"` + AwardType string `json:"award_type"` + EntryPriceTonNanos uint64 `json:"entry_price_ton_nanos"` + ProblemsIDs []int `json:"problems_ids" required:"true"` + StartTime time.Time `json:"start_time" required:"true"` + EndTime time.Time `json:"end_time" required:"true"` + DurationMins int `json:"duration_mins" requried:"true"` + MaxEntries int `json:"max_entries"` + AllowLateJoin bool `json:"allow_late_join"` +} + +type CreateProblem struct { Title string `json:"title" required:"true"` Statement string `json:"statement" required:"true"` Difficulty string `json:"difficulty" required:"true"` @@ -37,7 +44,26 @@ type CreateProblemRequest struct { TestCases []models.TestCaseDTO `json:"test_cases"` } -type CreateSubmissionRequest struct { +type CreateSubmission struct { Code string `json:"code"` Language string `json:"language"` } + +type TonProof struct { + Address string `json:"address"` + Network string `json:"network"` + Proof ProofData `json:"proof"` +} + +type ProofData struct { + Timestamp int64 `json:"timestamp"` + Domain Domain `json:"domain"` + Signature string `json:"signature"` + Payload string `json:"payload"` + StateInit string `json:"state_init"` +} + +type Domain struct { + LengthBytes int `json:"lengthBytes"` + Value string `json:"value"` +} diff --git a/internal/app/handler/dto/response/response.go b/internal/app/handler/dto/response/response.go index be5ab0c..6ed3d36 100644 --- a/internal/app/handler/dto/response/response.go +++ b/internal/app/handler/dto/response/response.go @@ -18,7 +18,7 @@ type Meta struct { } type ID struct { - ID int32 `json:"id"` + ID int `json:"id"` } type Token struct { @@ -26,54 +26,86 @@ type Token struct { } type Account struct { - ID int32 `json:"id"` + ID int `json:"id"` Username string `json:"username"` + Address string `json:"address,omitempty"` Role Role `json:"role"` } type Role struct { Name string `json:"name"` - CreatedProblemsLimit int32 `json:"created_problems_limit"` - CreatedContestsLimit int32 `json:"created_contests_limit"` + CreatedProblemsLimit int `json:"created_problems_limit"` + CreatedContestsLimit int `json:"created_contests_limit"` } type User struct { - ID int32 `json:"id"` + ID int `json:"id"` Username string `json:"username"` + Address string `json:"address,omitempty"` } type ContestDetailed struct { - ID int32 `json:"id"` - Creator User `json:"creator"` + ID int `json:"id"` Title string `json:"title"` Description string `json:"description"` + Creator User `json:"creator"` + Address string `json:"address,omitempty"` StartTime time.Time `json:"start_time"` EndTime time.Time `json:"end_time"` - DurationMins int32 `json:"duration_mins"` - MaxEntries int32 `json:"max_entries,omitempty"` - Participants int32 `json:"participants"` - AllowLateJoin bool `json:"allow_late_join"` - IsParticipant bool `json:"is_participant,omitempty"` - SubmissionDeadline *time.Time `json:"submission_deadline,omitempty"` + DurationMins int `json:"duration_mins"` + Participants int `json:"participants"` + MaxEntries int `json:"max_entries,omitempty"` + IsRegistrationOpen bool `json:"is_registration_open"` + EntryPriceTonNanos uint64 `json:"entry_price_ton_nanos"` + Entry *Entry `json:"entry,omitempty"` + Awards Awards `json:"awards"` Problems []ContestProblemListItem `json:"problems"` CreatedAt time.Time `json:"created_at"` } +type Awards struct { + Kind string `json:"kind"` + Nanocoins uint64 `json:"nanocoins"` + IsDistributed bool `json:"is_distributed"` + DistributionTxHash string `json:"distribution_tx_hash,omitempty"` +} + +type Entry struct { + IsAdmitted bool `json:"is_admitted"` + SubmissionDeadline *time.Time `json:"submission_deadline,omitempty"` + Message string `json:"message,omitempty"` + IsPaid bool `json:"is_paid"` + Payment *PaymentDetails `json:"payment,omitempty"` + CreatedAt time.Time `json:"created_at"` +} + +type PaymentDetails struct { + TxHash string `json:"tx_hash"` + CreatedAt time.Time `json:"created_at"` +} + +type Prizes struct { + Nanos uint64 `json:"ton_nanos"` +} + type ContestListItem struct { - ID int32 `json:"id"` - Creator User `json:"creator"` - Title string `json:"title"` - StartTime time.Time `json:"start_time"` - EndTime time.Time `json:"end_time"` - DurationMins int32 `json:"duration_mins"` - MaxEntries int32 `json:"max_entries,omitempty"` - Participants int32 `json:"participants"` - CreatedAt time.Time `json:"created_at"` + ID int `json:"id"` + Creator User `json:"creator"` + Title string `json:"title"` + AwardType string `json:"award_type"` + EntryPriceTonNanos uint64 `json:"entry_price_ton_nanos"` + StartTime time.Time `json:"start_time"` + EndTime time.Time `json:"end_time"` + DurationMins int `json:"duration_mins"` + MaxEntries int `json:"max_entries,omitempty"` + Participants int `json:"participants"` + AwardDistributed bool `json:"award_distributed"` + CreatedAt time.Time `json:"created_at"` } type Submission struct { - ID int32 `json:"id"` - ProblemID int32 `json:"problem_id"` + ID int `json:"id"` + ProblemID int `json:"problem_id"` Status string `json:"status"` Verdict string `json:"verdict"` Code string `json:"code,omitempty"` @@ -83,11 +115,11 @@ type Submission struct { } type TestingReport struct { - ID int32 `json:"id"` - PassedTestsCount int32 `json:"passed_tests_count"` - TotalTestsCount int32 `json:"total_tests_count"` + ID int `json:"id"` + PassedTestsCount int `json:"passed_tests_count"` + TotalTestsCount int `json:"total_tests_count"` FailedTest *Test `json:"failed_test,omitempty"` - Stderr string `json:"stderr"` + Stderr string `json:"stderr,omitemtpy"` CreatedAt time.Time `json:"created_at"` } @@ -98,55 +130,55 @@ type Test struct { } type ContestProblemDetailed struct { - ID int32 `json:"id"` + ID int `json:"id"` Charcode string `json:"charcode"` - ContestID int32 `json:"contest_id"` + ContestID int `json:"contest_id"` Writer User `json:"writer"` Title string `json:"title"` Statement string `json:"statement"` Examples []TC `json:"examples,omitempty"` Difficulty string `json:"difficulty"` Status string `json:"status,omitempty"` - TimeLimitMS int32 `json:"time_limit_ms"` - MemoryLimitMB int32 `json:"memory_limit_mb"` + TimeLimitMS int `json:"time_limit_ms"` + MemoryLimitMB int `json:"memory_limit_mb"` Checker string `json:"checker"` SubmissionDeadline *time.Time `json:"submission_deadline,omitempty"` CreatedAt time.Time `json:"created_at"` } type ContestProblemListItem struct { - ID int32 `json:"id"` + ID int `json:"id"` Charcode string `json:"charcode"` Writer User `json:"writer"` Title string `json:"title"` Difficulty string `json:"difficulty"` Status string `json:"status,omitempty"` - TimeLimitMS int32 `json:"time_limit_ms"` - MemoryLimitMB int32 `json:"memory_limit_mb"` + TimeLimitMS int `json:"time_limit_ms"` + MemoryLimitMB int `json:"memory_limit_mb"` Checker string `json:"checker"` CreatedAt time.Time `json:"created_at"` } type ProblemDetailed struct { - ID int32 `json:"id"` + ID int `json:"id"` Writer User `json:"writer"` Title string `json:"title"` Statement string `json:"statement"` Examples []TC `json:"examples,omitempty"` Difficulty string `json:"difficulty"` - TimeLimitMS int32 `json:"time_limit_ms"` - MemoryLimitMB int32 `json:"memory_limit_mb"` + TimeLimitMS int `json:"time_limit_ms"` + MemoryLimitMB int `json:"memory_limit_mb"` Checker string `json:"checker"` CreatedAt time.Time `json:"created_at"` } type ProblemListItem struct { - ID int32 `json:"id"` + ID int `json:"id"` Writer User `json:"writer"` Title string `json:"title"` Difficulty string `json:"difficulty"` - TimeLimitMS int32 `json:"time_limit_ms"` - MemoryLimitMB int32 `json:"memory_limit_mb"` + TimeLimitMS int `json:"time_limit_ms"` + MemoryLimitMB int `json:"memory_limit_mb"` Checker string `json:"checker"` CreatedAt time.Time `json:"created_at"` } diff --git a/internal/app/handler/entry.go b/internal/app/handler/entry.go index 1f0c46d..95e0825 100644 --- a/internal/app/handler/entry.go +++ b/internal/app/handler/entry.go @@ -2,16 +2,13 @@ package handler import ( "errors" - "fmt" "net/http" - "time" - "github.com/jackc/pgx/v5" "github.com/labstack/echo/v4" + "github.com/voidcontests/api/internal/app/service" ) func (h *Handler) CreateEntry(c echo.Context) error { - op := "handler.CreateEntry" ctx := c.Request().Context() claims, _ := ExtractClaims(c) @@ -21,40 +18,21 @@ func (h *Handler) CreateEntry(c echo.Context) error { return Error(http.StatusBadRequest, "contest ID should be an integer") } - contest, err := h.repo.Contest.GetByID(ctx, int32(contestID)) - if errors.Is(err, pgx.ErrNoRows) { - return Error(http.StatusNotFound, "contest not found") - } - if err != nil { - return fmt.Errorf("%s: can't get contest: %v", op, err) - } - - entries, err := h.repo.Contest.GetEntriesCount(ctx, int32(contestID)) + _, err := h.service.Contest.CreateEntry(ctx, contestID, claims.UserID) if err != nil { - return fmt.Errorf("%s: can't get entries: %v", op, err) - } - - if contest.MaxEntries != 0 && entries >= contest.MaxEntries { - return Error(http.StatusConflict, "max slots limit reached") - } - - // NOTE: disallow join if: contest already finished or (already started and no late joins) - if contest.EndTime.Before(time.Now()) || (contest.StartTime.Before(time.Now()) && !contest.AllowLateJoin) { - return Error(http.StatusForbidden, "application time is over") - } - - _, err = h.repo.Entry.Get(ctx, int32(contestID), claims.UserID) - if errors.Is(err, pgx.ErrNoRows) { - _, err = h.repo.Entry.Create(ctx, int32(contestID), claims.UserID) - if err != nil { - return fmt.Errorf("%s: can't create entry: %v", op, err) + switch { + case errors.Is(err, service.ErrContestNotFound): + return Error(http.StatusNotFound, "contest not found") + case errors.Is(err, service.ErrMaxSlotsReached): + return Error(http.StatusConflict, "max slots limit reached") + case errors.Is(err, service.ErrApplicationTimeOver): + return Error(http.StatusForbidden, "application time is over") + case errors.Is(err, service.ErrEntryAlreadyExists): + return Error(http.StatusConflict, "user already has entry for this contest") + default: + return err } - - return c.NoContent(http.StatusCreated) - } - if err != nil { - return fmt.Errorf("%s: can't get entry: %v", op, err) } - return Error(http.StatusConflict, "user already has entry for this contest") + return c.NoContent(http.StatusCreated) } diff --git a/internal/app/handler/error.go b/internal/app/handler/error.go new file mode 100644 index 0000000..ff97c56 --- /dev/null +++ b/internal/app/handler/error.go @@ -0,0 +1,36 @@ +package handler + +import ( + "log/slog" + "net/http" + + "github.com/labstack/echo/v4" + "github.com/voidcontests/api/internal/lib/logger/sl" + "github.com/voidcontests/api/pkg/requestid" +) + +func ErorHTTP(err error, c echo.Context) { + requestID := requestid.Get(c) + if he, ok := err.(*echo.HTTPError); ok && (he.Code == http.StatusNotFound || he.Code == http.StatusMethodNotAllowed) { + c.JSON(http.StatusNotFound, map[string]string{ + "message": "resource not found", + "request_id": requestID, + }) + return + } + + if ae, ok := err.(*APIError); ok { + slog.Debug("responded with API error", sl.Err(err), slog.String("request_id", requestid.Get(c))) + c.JSON(ae.Status, map[string]any{ + "message": ae.Message, + "request_id": requestID, + }) + return + } + + slog.Error("something went wrong", sl.Err(err), slog.String("request_id", requestid.Get(c))) + c.JSON(http.StatusInternalServerError, map[string]any{ + "message": "internal server error", + "request_id": requestID, + }) +} diff --git a/internal/app/handler/handler.go b/internal/app/handler/handler.go index 979127f..2870c50 100644 --- a/internal/app/handler/handler.go +++ b/internal/app/handler/handler.go @@ -5,23 +5,28 @@ import ( "strconv" "github.com/labstack/echo/v4" + "github.com/voidcontests/api/internal/app/service" "github.com/voidcontests/api/internal/config" "github.com/voidcontests/api/internal/jwt" + "github.com/voidcontests/api/internal/lib/crypto" "github.com/voidcontests/api/internal/storage/broker" "github.com/voidcontests/api/internal/storage/repository" + "github.com/voidcontests/api/pkg/ton" ) type Handler struct { - config *config.Config - repo *repository.Repository - broker broker.Broker + config *config.Config + repo *repository.Repository + broker broker.Broker + service *service.Service } -func New(c *config.Config, r *repository.Repository, b broker.Broker) *Handler { +func New(c *config.Config, r *repository.Repository, b broker.Broker, tc *ton.Client, cipher crypto.Cipher) *Handler { return &Handler{ - config: c, - repo: r, - broker: b, + config: c, + repo: r, + broker: b, + service: service.New(&c.Security, r, b, tc, cipher), } } diff --git a/internal/app/handler/problem.go b/internal/app/handler/problem.go index aa9dceb..1334bc0 100644 --- a/internal/app/handler/problem.go +++ b/internal/app/handler/problem.go @@ -2,111 +2,79 @@ package handler import ( "errors" - "fmt" "net/http" - "strings" "time" - "github.com/jackc/pgx/v5" "github.com/labstack/echo/v4" "github.com/voidcontests/api/internal/app/handler/dto/request" "github.com/voidcontests/api/internal/app/handler/dto/response" - "github.com/voidcontests/api/internal/storage/models" + "github.com/voidcontests/api/internal/app/service" "github.com/voidcontests/api/pkg/validate" ) func (h *Handler) CreateProblem(c echo.Context) error { - op := "handler.CreateProblem" ctx := c.Request().Context() claims, _ := ExtractClaims(c) - var body request.CreateProblemRequest + var body request.CreateProblem if err := validate.Bind(c, &body); err != nil { return Error(http.StatusBadRequest, "invalid body: missing required fields") } - userrole, err := h.repo.User.GetRole(ctx, claims.UserID) + id, err := h.service.Problem.CreateProblem(ctx, service.CreateProblemParams{ + UserID: claims.UserID, + Title: body.Title, + Statement: body.Statement, + Difficulty: body.Difficulty, + TimeLimitMS: body.TimeLimitMS, + MemoryLimitMB: body.MemoryLimitMB, + Checker: body.Checker, + TestCases: body.TestCases, + }) if err != nil { - return fmt.Errorf("%s: can't get role: %v", op, err) - } - - if userrole.Name == models.RoleBanned { - return Error(http.StatusForbidden, "you are banned from creating problems") - } - - if userrole.Name == models.RoleLimited { - pscount, err := h.repo.User.GetCreatedProblemsCount(ctx, claims.UserID) - if err != nil { - return fmt.Errorf("%s: can't get created problems count: %v", op, err) - } - - if pscount >= int(userrole.CreatedProblemsLimit) { + switch { + case errors.Is(err, service.ErrUserBanned): + return Error(http.StatusForbidden, "you are banned from creating problems") + case errors.Is(err, service.ErrProblemsLimitExceeded): return Error(http.StatusForbidden, "problems limit exceeded") + case errors.Is(err, service.ErrInvalidTimeLimit): + return Error(http.StatusBadRequest, "time_limit_ms must be between 500 and 10000") + case errors.Is(err, service.ErrInvalidMemoryLimit): + return Error(http.StatusBadRequest, "memory_limit_mb must be between 16 and 512") + default: + return err } } - if body.TimeLimitMS < 500 || body.TimeLimitMS > 10000 { - return Error(http.StatusBadRequest, "time_limit_ms must be between 500 and 10000") - } - - if body.MemoryLimitMB < 16 || body.MemoryLimitMB > 512 { - return Error(http.StatusBadRequest, "memory_limit_mb must be between 16 and 512") - } - - // TODO: Remove examples as database entity - // Forbid to create more examples than 3 - examplesCount := 0 - for i := range body.TestCases { - if body.TestCases[i].IsExample { - examplesCount++ - } - - if examplesCount > 3 && body.TestCases[i].IsExample { - body.TestCases[i].IsExample = false - } - } - - checker := body.Checker - if checker == "" { - checker = "tokens" - } - - problemID, err := h.repo.Problem.CreateWithTCs(ctx, claims.UserID, body.Title, body.Statement, body.Difficulty, body.TimeLimitMS, body.MemoryLimitMB, checker, body.TestCases) - - if err != nil { - return fmt.Errorf("%s: can't create problem: %v", op, err) - } - return c.JSON(http.StatusCreated, response.ID{ - ID: problemID, + ID: id, }) } func (h *Handler) GetCreatedProblems(c echo.Context) error { - op := "handler.GetCreatedProblems" ctx := c.Request().Context() claims, _ := ExtractClaims(c) limit, ok := ExtractQueryParamInt(c, "limit") - if !ok { + if !ok || limit < 0 { limit = 10 } offset, ok := ExtractQueryParamInt(c, "offset") - if !ok { + if !ok || offset < 0 { offset = 0 } - ps, total, err := h.repo.Problem.GetWithWriterID(ctx, claims.UserID, limit, offset) + result, err := h.service.Problem.GetCreatedProblems(ctx, claims.UserID, limit, offset) if err != nil { - return fmt.Errorf("%s: can't get created problems: %v", op, err) + return err } - n := len(ps) + n := len(result.Problems) problems := make([]response.ProblemListItem, n, n) - for i, p := range ps { + for i, p := range result.Problems { problems[i] = response.ProblemListItem{ ID: p.ID, Title: p.Title, @@ -118,16 +86,17 @@ func (h *Handler) GetCreatedProblems(c echo.Context) error { Writer: response.User{ ID: p.WriterID, Username: p.WriterUsername, + Address: p.WriterAddress, }, } } return c.JSON(http.StatusOK, response.Pagination[response.ProblemListItem]{ Meta: response.Meta{ - Total: total, + Total: result.Total, Limit: limit, Offset: offset, - HasNext: offset+limit < total, + HasNext: offset+limit < result.Total, HasPrev: offset > 0, }, Items: problems, @@ -135,7 +104,6 @@ func (h *Handler) GetCreatedProblems(c echo.Context) error { } func (h *Handler) GetContestProblem(c echo.Context) error { - op := "handler.GetContestProblem" ctx := c.Request().Context() claims, _ := ExtractClaims(c) @@ -146,68 +114,46 @@ func (h *Handler) GetContestProblem(c echo.Context) error { } charcode := c.Param("charcode") - if len(charcode) > 2 { - return Error(http.StatusBadRequest, "problem charcode couldn't be longer than 2 characters") - } - charcode = strings.ToUpper(charcode) - contest, err := h.repo.Contest.GetByID(ctx, int32(contestID)) - if errors.Is(err, pgx.ErrNoRows) { - return Error(http.StatusNotFound, "contest not found") - } + details, err := h.service.Problem.GetContestProblem(ctx, contestID, claims.UserID, charcode) if err != nil { - return err - } - - now := time.Now() - if contest.StartTime.After(now) { - return Error(http.StatusForbidden, "contest not started yet") - } - - entry, err := h.repo.Entry.Get(ctx, int32(contestID), claims.UserID) - if errors.Is(err, pgx.ErrNoRows) { - return Error(http.StatusForbidden, "no entry") - } - if err != nil { - return fmt.Errorf("%s: can't get entry: %v", op, err) - } - - p, err := h.repo.Problem.Get(ctx, int32(contestID), charcode) - if errors.Is(err, pgx.ErrNoRows) { - return Error(http.StatusNotFound, "problem not found") - } - if err != nil { - return fmt.Errorf("%s: can't get problem: %v", op, err) - } - - etc, err := h.repo.Problem.GetExampleCases(ctx, p.ID) - if err != nil { - return fmt.Errorf("%s: can't get tc examples: %v", op, err) + switch { + case errors.Is(err, service.ErrEntryNotPaid): + return Error(http.StatusForbidden, "entry not paid") + case errors.Is(err, service.ErrInvalidCharcode): + return Error(http.StatusBadRequest, "problem charcode couldn't be longer than 2 characters") + case errors.Is(err, service.ErrContestNotFound): + return Error(http.StatusNotFound, "contest not found") + case errors.Is(err, service.ErrContestNotStarted): + return Error(http.StatusForbidden, "contest not started yet") + case errors.Is(err, service.ErrNoEntryForContest): + return Error(http.StatusForbidden, "no entry") + case errors.Is(err, service.ErrProblemNotFound): + return Error(http.StatusNotFound, "problem not found") + default: + return err + } } - n := len(etc) + p := details.Problem + n := len(details.Examples) examples := make([]response.TC, n, n) for i := 0; i < n; i++ { examples[i] = response.TC{ - Input: etc[i].Input, - Output: etc[i].Output, + Input: details.Examples[i].Input, + Output: details.Examples[i].Output, } } - status, err := h.repo.Submission.GetProblemStatus(ctx, entry.ID, p.ID) - if err != nil { - return err - } - pdetailed := response.ContestProblemDetailed{ ID: p.ID, Charcode: p.Charcode, - ContestID: int32(contestID), + ContestID: contestID, Title: p.Title, Statement: p.Statement, Examples: examples, Difficulty: p.Difficulty, - Status: status, + Status: details.Status, CreatedAt: p.CreatedAt, TimeLimitMS: p.TimeLimitMS, MemoryLimitMB: p.MemoryLimitMB, @@ -215,19 +161,18 @@ func (h *Handler) GetContestProblem(c echo.Context) error { Writer: response.User{ ID: p.WriterID, Username: p.WriterUsername, + Address: p.WriterAddress, }, } - _, deadline := AllowSubmitAt(contest, entry) - if contest.StartTime.Before(time.Now()) { - pdetailed.SubmissionDeadline = &deadline + if details.SubmissionWindow.Earliest.Before(time.Now()) { + pdetailed.SubmissionDeadline = &details.SubmissionWindow.Deadline } return c.JSON(http.StatusOK, pdetailed) } func (h *Handler) GetProblemByID(c echo.Context) error { - op := "handler.GetProblemByID" ctx := c.Request().Context() claims, _ := ExtractClaims(c) @@ -237,29 +182,25 @@ func (h *Handler) GetProblemByID(c echo.Context) error { return Error(http.StatusBadRequest, "problem ID should be an integer") } - problem, err := h.repo.Problem.GetByID(ctx, int32(problemID)) - if errors.Is(err, pgx.ErrNoRows) { - return Error(http.StatusNotFound, "problem not found") - } + details, err := h.service.Problem.GetProblemByID(ctx, problemID, claims.UserID) if err != nil { - return fmt.Errorf("%s: can't get problem: %v", op, err) - } - - if problem.WriterID != claims.UserID { - return Error(http.StatusNotFound, "problem not found") - } - - etc, err := h.repo.Problem.GetExampleCases(ctx, problem.ID) - if err != nil { - return fmt.Errorf("%s: can't get tc examples: %v", op, err) + switch { + case errors.Is(err, service.ErrProblemNotFound): + return Error(http.StatusNotFound, "problem not found") + case errors.Is(err, service.ErrNotProblemWriter): + return Error(http.StatusNotFound, "problem not found") + default: + return err + } } - n := len(etc) + problem := details.Problem + n := len(details.Examples) examples := make([]response.TC, n, n) for i := 0; i < n; i++ { examples[i] = response.TC{ - Input: etc[i].Input, - Output: etc[i].Output, + Input: details.Examples[i].Input, + Output: details.Examples[i].Output, } } @@ -276,6 +217,7 @@ func (h *Handler) GetProblemByID(c echo.Context) error { Writer: response.User{ ID: problem.WriterID, Username: problem.WriterUsername, + Address: problem.WriterAddress, }, } diff --git a/internal/app/handler/submission.go b/internal/app/handler/submission.go index 14cf7bc..8bc2e11 100644 --- a/internal/app/handler/submission.go +++ b/internal/app/handler/submission.go @@ -2,24 +2,16 @@ package handler import ( "errors" - "log/slog" "net/http" - "strings" - "time" - "github.com/jackc/pgx/v5" "github.com/labstack/echo/v4" "github.com/voidcontests/api/internal/app/handler/dto/request" "github.com/voidcontests/api/internal/app/handler/dto/response" - "github.com/voidcontests/api/internal/lib/logger/sl" - "github.com/voidcontests/api/internal/storage/models" - "github.com/voidcontests/api/internal/storage/models/status" - "github.com/voidcontests/api/pkg/requestid" + "github.com/voidcontests/api/internal/app/service" "github.com/voidcontests/api/pkg/validate" ) func (h *Handler) CreateSubmission(c echo.Context) error { - log := slog.With(slog.String("op", "handler.CreateSubmission"), slog.String("request_id", requestid.Get(c))) ctx := c.Request().Context() claims, _ := ExtractClaims(c) @@ -30,66 +22,37 @@ func (h *Handler) CreateSubmission(c echo.Context) error { } charcode := c.Param("charcode") - if len(charcode) > 2 { - return Error(http.StatusBadRequest, "problem's `charcode` couldn't be longer than 2 characters") - } - charcode = strings.ToUpper(charcode) - var body request.CreateSubmissionRequest + var body request.CreateSubmission if err := validate.Bind(c, &body); err != nil { - log.Debug("can't decode request body", sl.Err(err)) return Error(http.StatusBadRequest, "invalid body") } - contest, err := h.repo.Contest.GetByID(ctx, int32(contestID)) - if errors.Is(err, pgx.ErrNoRows) { - return Error(http.StatusNotFound, "contest not found") - } - if err != nil { - log.Error("can't get contest", sl.Err(err)) - return err - } - - entry, err := h.repo.Entry.Get(ctx, int32(contestID), claims.UserID) - if errors.Is(err, pgx.ErrNoRows) { - log.Debug("trying to create submission without entry") - return Error(http.StatusForbidden, "no entry for contest") - } - if err != nil { - log.Error("can't get entry", sl.Err(err)) - return err - } - - now := time.Now() - earliest, deadline := AllowSubmitAt(contest, entry) - if earliest.After(now) || deadline.Before(now) { - return Error(http.StatusForbidden, "submission window is currently closed") - } - - problem, err := h.repo.Problem.Get(ctx, int32(contestID), charcode) - if errors.Is(err, pgx.ErrNoRows) { - return Error(http.StatusNotFound, "problem not found") - } - if err != nil { - log.Error("can't get problem", sl.Err(err)) - return err - } - - s, err := h.repo.Submission.Create(ctx, entry.ID, problem.ID, body.Code, body.Language) + result, err := h.service.Submission.CreateSubmission(ctx, service.CreateSubmissionParams{ + ContestID: contestID, + UserID: claims.UserID, + Charcode: charcode, + Code: body.Code, + Language: body.Language, + }) if err != nil { - log.Error("can't create submission", sl.Err(err)) - return err - } - - // TODO: create initial testing report in database - - if err := h.broker.PublishSubmission(ctx, s); err != nil { - log.Error("can't publish submission", sl.Err(err)) - // TODO: if we can't push submission into execution queue, try to save it to local memory, and try to push later (?) - // - but is it really needed, after some time? - return err + switch { + case errors.Is(err, service.ErrInvalidCharcode): + return Error(http.StatusBadRequest, "problem's `charcode` couldn't be longer than 2 characters") + case errors.Is(err, service.ErrContestNotFound): + return Error(http.StatusNotFound, "contest not found") + case errors.Is(err, service.ErrNoEntryForContest): + return Error(http.StatusForbidden, "no entry for contest") + case errors.Is(err, service.ErrSubmissionWindowClosed): + return Error(http.StatusForbidden, "submission window is currently closed") + case errors.Is(err, service.ErrProblemNotFound): + return Error(http.StatusNotFound, "problem not found") + default: + return err + } } + s := result.Submission return c.JSON(http.StatusCreated, response.Submission{ ID: s.ID, ProblemID: s.ProblemID, @@ -100,53 +63,53 @@ func (h *Handler) CreateSubmission(c echo.Context) error { } func (h *Handler) GetSubmissionByID(c echo.Context) error { - log := slog.With(slog.String("op", "handler.GetSubmissionByID"), slog.String("request_id", requestid.Get(c))) ctx := c.Request().Context() - // TODO: check if submission is submitted by request initiator - _, _ = ExtractClaims(c) + claims, _ := ExtractClaims(c) submissionID, ok := ExtractParamInt(c, "sid") if !ok { return Error(http.StatusBadRequest, "submission ID should be an integer") } - s, err := h.repo.Submission.GetByID(ctx, int32(submissionID)) - if errors.Is(err, pgx.ErrNoRows) { - return Error(http.StatusNotFound, "submission not found") - } + details, err := h.service.Submission.GetSubmissionByID(ctx, submissionID, claims.UserID) if err != nil { - log.Error("can't get submissions", sl.Err(err)) - return err + switch { + case errors.Is(err, service.ErrSubmissionNotFound): + return Error(http.StatusNotFound, "submission not found") + case errors.Is(err, service.ErrUnauthorizedAccess): + return Error(http.StatusNotFound, "submission not found") + default: + return err + } } - if s.Status != status.Success { + submission := details.Submission + + // If no testing report, return basic submission info + if details.TestingReport == nil { return c.JSON(http.StatusOK, response.Submission{ - ID: s.ID, - ProblemID: s.ProblemID, - Status: s.Status, - Verdict: s.Verdict, - Code: s.Code, - Language: s.Language, - CreatedAt: s.CreatedAt, + ID: submission.ID, + ProblemID: submission.ProblemID, + Status: submission.Status, + Verdict: submission.Verdict, + Code: submission.Code, + Language: submission.Language, + CreatedAt: submission.CreatedAt, }) - } - tr, err := h.repo.Submission.GetTestingReport(ctx, s.ID) - if err != nil { - log.Error("can't get testing report", sl.Err(err)) - return err - } + tr := details.TestingReport - if tr.FirstFailedTestID == nil { + // If no failed test, return with testing report + if details.FailedTest == nil { return c.JSON(http.StatusOK, response.Submission{ - ID: s.ID, - ProblemID: s.ProblemID, - Status: s.Status, - Verdict: s.Verdict, - Code: s.Code, - Language: s.Language, + ID: submission.ID, + ProblemID: submission.ProblemID, + Status: submission.Status, + Verdict: submission.Verdict, + Code: submission.Code, + Language: submission.Language, TestingReport: &response.TestingReport{ ID: tr.ID, PassedTestsCount: tr.PassedTestsCount, @@ -154,23 +117,19 @@ func (h *Handler) GetSubmissionByID(c echo.Context) error { Stderr: tr.Stderr, CreatedAt: tr.CreatedAt, }, - CreatedAt: s.CreatedAt, + CreatedAt: submission.CreatedAt, }) } - ftc, err := h.repo.Problem.GetTestCaseByID(ctx, *tr.FirstFailedTestID) - if err != nil { - log.Error("can't get test case", sl.Err(err)) - return err - } - + // Return with full testing report including failed test + ftc := details.FailedTest return c.JSON(http.StatusOK, response.Submission{ - ID: s.ID, - ProblemID: s.ProblemID, - Status: s.Status, - Verdict: s.Verdict, - Code: s.Code, - Language: s.Language, + ID: submission.ID, + ProblemID: submission.ProblemID, + Status: submission.Status, + Verdict: submission.Verdict, + Code: submission.Code, + Language: submission.Language, TestingReport: &response.TestingReport{ ID: tr.ID, PassedTestsCount: tr.PassedTestsCount, @@ -183,12 +142,11 @@ func (h *Handler) GetSubmissionByID(c echo.Context) error { Stderr: tr.Stderr, CreatedAt: tr.CreatedAt, }, - CreatedAt: s.CreatedAt, + CreatedAt: submission.CreatedAt, }) } func (h *Handler) GetSubmissions(c echo.Context) error { - log := slog.With(slog.String("op", "handler.GetSubmissions"), slog.String("request_id", requestid.Get(c))) ctx := c.Request().Context() claims, _ := ExtractClaims(c) @@ -199,39 +157,32 @@ func (h *Handler) GetSubmissions(c echo.Context) error { } charcode := c.Param("charcode") - if len(charcode) > 2 { - return Error(http.StatusBadRequest, "problem's `charcode` couldn't be longer than 2 characters") - } - charcode = strings.ToUpper(charcode) limit, ok := ExtractQueryParamInt(c, "limit") - if !ok { + if !ok || limit < 0 { limit = 10 } offset, ok := ExtractQueryParamInt(c, "offset") - if !ok { + if !ok || offset < 0 { offset = 0 } - entry, err := h.repo.Entry.Get(ctx, int32(contestID), claims.UserID) - if errors.Is(err, pgx.ErrNoRows) { - return Error(http.StatusForbidden, "no entry for contest") - } - if err != nil { - log.Error("can't get entry", sl.Err(err)) - return err - } - - submissions, total, err := h.repo.Submission.ListByProblem(ctx, entry.ID, charcode, limit, offset) + result, err := h.service.Submission.ListSubmissions(ctx, contestID, claims.UserID, charcode, limit, offset) if err != nil { - log.Error("can't get submissions", sl.Err(err)) - return err + switch { + case errors.Is(err, service.ErrInvalidCharcode): + return Error(http.StatusBadRequest, "problem's `charcode` couldn't be longer than 2 characters") + case errors.Is(err, service.ErrNoEntryForContest): + return Error(http.StatusForbidden, "no entry for contest") + default: + return err + } } - n := len(submissions) + n := len(result.Submissions) items := make([]response.Submission, n, n) - for i, submission := range submissions { + for i, submission := range result.Submissions { items[i] = response.Submission{ ID: submission.ID, ProblemID: submission.ProblemID, @@ -243,32 +194,12 @@ func (h *Handler) GetSubmissions(c echo.Context) error { return c.JSON(http.StatusOK, response.Pagination[response.Submission]{ Meta: response.Meta{ - Total: total, + Total: result.Total, Limit: limit, Offset: offset, - HasNext: offset+limit < total, + HasNext: offset+limit < result.Total, HasPrev: offset > 0, }, Items: items, }) } - -func AllowSubmitAt(contest models.Contest, entry models.Entry) (earliest time.Time, deadline time.Time) { - if contest.DurationMins == 0 { - return contest.StartTime, contest.EndTime - } - - earliest = entry.CreatedAt - if contest.StartTime.After(earliest) { - earliest = contest.StartTime - } - - personalDeadline := earliest.Add(time.Duration(contest.DurationMins) * time.Minute) - if personalDeadline.Before(contest.EndTime) { - deadline = personalDeadline - } else { - deadline = contest.EndTime - } - - return earliest, deadline -} diff --git a/internal/app/handler/tonproof.go b/internal/app/handler/tonproof.go new file mode 100644 index 0000000..5a5b369 --- /dev/null +++ b/internal/app/handler/tonproof.go @@ -0,0 +1,41 @@ +package handler + +import ( + "net/http" + + "github.com/labstack/echo/v4" + "github.com/voidcontests/api/internal/app/handler/dto/request" + "github.com/voidcontests/api/pkg/validate" +) + +func (h *Handler) GeneratePayload(c echo.Context) error { + payload, err := h.service.TonProof.GeneratePayload() + if err != nil { + return err + } + + return c.JSON(http.StatusOK, map[string]any{ + "payload": payload, + }) +} + +func (h *Handler) CheckProof(c echo.Context) error { + ctx := c.Request().Context() + + claims, ok := ExtractClaims(c) + if !ok { + return Error(http.StatusUnauthorized, "user not authenticated") + } + + var tp request.TonProof + if err := validate.Bind(c, &tp); err != nil { + return Error(http.StatusBadRequest, "invalid request body") + } + + err := h.service.TonProof.VerifyProofAndSetAddress(ctx, claims.UserID, tp) + if err != nil { + return Error(http.StatusUnauthorized, "tonproof verification failed") + } + + return c.NoContent(http.StatusOK) +} diff --git a/internal/app/router/router.go b/internal/app/router/router.go index 61f47eb..61a7c5f 100644 --- a/internal/app/router/router.go +++ b/internal/app/router/router.go @@ -1,7 +1,6 @@ package router import ( - "log/slog" "net/http" "time" @@ -9,12 +8,13 @@ import ( "github.com/labstack/echo/v4/middleware" "github.com/voidcontests/api/internal/app/handler" "github.com/voidcontests/api/internal/config" - "github.com/voidcontests/api/internal/lib/logger/sl" + "github.com/voidcontests/api/internal/lib/crypto" "github.com/voidcontests/api/internal/storage/broker" "github.com/voidcontests/api/internal/storage/repository" "github.com/voidcontests/api/pkg/ratelimit" "github.com/voidcontests/api/pkg/requestid" "github.com/voidcontests/api/pkg/requestlog" + "github.com/voidcontests/api/pkg/ton" ) type Router struct { @@ -22,35 +22,15 @@ type Router struct { handler *handler.Handler } -func New(c *config.Config, r *repository.Repository, b broker.Broker) *Router { - h := handler.New(c, r, b) +func New(c *config.Config, r *repository.Repository, b broker.Broker, tc *ton.Client, cipher crypto.Cipher) *Router { + h := handler.New(c, r, b, tc, cipher) return &Router{config: c, handler: h} } func (r *Router) InitRoutes() *echo.Echo { router := echo.New() - router.HTTPErrorHandler = func(err error, c echo.Context) { - if he, ok := err.(*echo.HTTPError); ok && (he.Code == http.StatusNotFound || he.Code == http.StatusMethodNotAllowed) { - c.JSON(http.StatusNotFound, map[string]string{ - "message": "resource not found", - }) - return - } - - if ae, ok := err.(*handler.APIError); ok { - slog.Debug("responded with API error", sl.Err(err), slog.String("request_id", requestid.Get(c))) - c.JSON(ae.Status, map[string]any{ - "message": ae.Message, - }) - return - } - - slog.Error("something went wrong", sl.Err(err), slog.String("request_id", requestid.Get(c))) - c.JSON(http.StatusInternalServerError, map[string]any{ - "message": "internal server error", - }) - } + router.HTTPErrorHandler = handler.ErorHTTP router.Use(requestid.New) router.Use(requestlog.Completed) @@ -74,30 +54,40 @@ func (r *Router) InitRoutes() *echo.Echo { }) } + // TODO: update rate limiting logic: + // Current: + // - request -> wait Ns -> request + // + // Expected: + // - [request -> request -> request] - in such window, forbid to make more than M requests + // ^ 0s Ns ^ + api := router.Group("/api") { api.GET("/healthcheck", r.handler.Healthcheck) + tonproof := api.Group("/tonproof") + tonproof.POST("/payload", r.handler.GeneratePayload) + tonproof.POST("/check", r.handler.CheckProof, r.handler.MustIdentify()) + api.GET("/account", r.handler.GetAccount, r.handler.MustIdentify()) - api.POST("/account", r.handler.CreateAccount) - api.POST("/session", r.handler.CreateSession) + api.POST("/account", r.handler.CreateAccount, ratelimit.WithTimeout(5*time.Second)) + api.PATCH("/account", r.handler.UpdateAccount, r.handler.MustIdentify()) + api.POST("/session", r.handler.CreateSession, ratelimit.WithTimeout(2*time.Second)) - // TODO: make this endpoints as filter to general endpoint, like: - // GET /contests?creator_id=69 - // GET /problems?writer_id=420 - api.GET("/creator/contests", r.handler.GetCreatedContests, r.handler.MustIdentify()) - api.GET("/creator/problems", r.handler.GetCreatedProblems, r.handler.MustIdentify()) + api.GET("/account/contests", r.handler.GetCreatedContests, r.handler.MustIdentify()) + api.GET("/account/problems", r.handler.GetCreatedProblems, r.handler.MustIdentify()) - api.POST("/problems", r.handler.CreateProblem, r.handler.MustIdentify()) + api.POST("/problems", r.handler.CreateProblem, ratelimit.WithTimeout(3*time.Second), r.handler.MustIdentify()) api.GET("/problems/:pid", r.handler.GetProblemByID, r.handler.MustIdentify()) api.GET("/contests", r.handler.GetContests) - api.POST("/contests", r.handler.CreateContest, r.handler.MustIdentify()) + api.POST("/contests", r.handler.CreateContest, ratelimit.WithTimeout(3*time.Second), r.handler.MustIdentify()) api.GET("/contests/:cid", r.handler.GetContestByID, r.handler.TryIdentify()) - api.POST("/contests/:cid/entry", r.handler.CreateEntry, r.handler.MustIdentify()) - api.GET("/contests/:cid/leaderboard", r.handler.GetLeaderboard) + api.POST("/contests/:cid/entry", r.handler.CreateEntry, ratelimit.WithTimeout(3*time.Second), r.handler.MustIdentify()) + api.GET("/contests/:cid/scores", r.handler.GetScores) api.GET("/contests/:cid/problems/:charcode", r.handler.GetContestProblem, r.handler.MustIdentify()) api.GET("/contests/:cid/problems/:charcode/submissions", r.handler.GetSubmissions, r.handler.MustIdentify()) diff --git a/internal/app/service/account.go b/internal/app/service/account.go new file mode 100644 index 0000000..65e5a14 --- /dev/null +++ b/internal/app/service/account.go @@ -0,0 +1,121 @@ +package service + +import ( + "context" + "errors" + "fmt" + + "github.com/jackc/pgx/v5" + "github.com/voidcontests/api/internal/config" + "github.com/voidcontests/api/internal/hasher" + "github.com/voidcontests/api/internal/jwt" + "github.com/voidcontests/api/internal/storage/models" + "github.com/voidcontests/api/internal/storage/repository" +) + +type AccountService struct { + config *config.Security + repo *repository.Repository +} + +func NewAccountService(cfg *config.Security, repo *repository.Repository) *AccountService { + return &AccountService{ + config: cfg, + repo: repo, + } +} + +func (s *AccountService) CreateAccount(ctx context.Context, username, password string) (int, error) { + op := "service.AccountService.CreateAccount" + + exists, err := s.repo.User.Exists(ctx, username) + if err != nil { + return 0, fmt.Errorf("%s: can't verify that user exists: %w", op, err) + } + + if exists { + return 0, ErrUserAlreadyExists + } + + passwordHash := hasher.Sha256String([]byte(password), []byte(s.config.Salt)) + + user, err := s.repo.User.Create(ctx, username, passwordHash) + if err != nil { + return 0, fmt.Errorf("%s: failed to create user: %w", op, err) + } + + return user.ID, nil +} + +func (s *AccountService) CreateSession(ctx context.Context, username, password string) (string, error) { + op := "service.AccountService.CreateSession" + + passwordHash := hasher.Sha256String([]byte(password), []byte(s.config.Salt)) + + user, err := s.repo.User.GetByCredentials(ctx, username, passwordHash) + if errors.Is(err, pgx.ErrNoRows) { + return "", ErrInvalidCredentials + } + if err != nil { + return "", fmt.Errorf("%s: failed to get user by credentials: %w", op, err) + } + + token, err := jwt.GenerateToken(user.ID, s.config.SignatureKey) + if err != nil { + return "", fmt.Errorf("%s: %w: %v", op, ErrTokenGeneration, err) + } + + return token, nil +} + +type AccountInfo struct { + User models.User + Role models.Role +} + +func (s *AccountService) GetAccount(ctx context.Context, userID int) (*AccountInfo, error) { + op := "service.AccountService.GetAccount" + + user, err := s.repo.User.GetByID(ctx, userID) + if errors.Is(err, pgx.ErrNoRows) { + return nil, ErrInvalidToken + } + if err != nil { + return nil, fmt.Errorf("%s: failed to get user: %w", op, err) + } + + role, err := s.repo.User.GetRole(ctx, userID) + if err != nil { + return nil, fmt.Errorf("%s: failed to get role: %w", op, err) + } + + return &AccountInfo{ + User: user, + Role: role, + }, nil +} + +func (s *AccountService) UpdateAccount(ctx context.Context, userID int, params models.UpdateUserParams) (*models.User, error) { + op := "service.AccountService.UpdateAccount" + + if params.Username != nil { + existingUser, err := s.repo.User.GetByUsername(ctx, *params.Username) + if err != nil && !errors.Is(err, pgx.ErrNoRows) { + return nil, fmt.Errorf("%s: failed to check username availability: %w", op, err) + } + + if err == nil && existingUser.ID != userID { + return nil, ErrUserAlreadyExists + } + } + + user, err := s.repo.User.UpdateUser(ctx, userID, params) + if errors.Is(err, pgx.ErrNoRows) { + return nil, ErrInvalidToken + } + if err != nil { + return nil, fmt.Errorf("%s: failed to update user: %w", op, err) + } + + return &user, nil +} diff --git a/internal/app/service/contest.go b/internal/app/service/contest.go new file mode 100644 index 0000000..7bc68ba --- /dev/null +++ b/internal/app/service/contest.go @@ -0,0 +1,505 @@ +package service + +import ( + "context" + "errors" + "fmt" + "strings" + "time" + + "github.com/jackc/pgx/v5" + "github.com/voidcontests/api/internal/lib/crypto" + "github.com/voidcontests/api/internal/storage/models" + "github.com/voidcontests/api/internal/storage/models/award" + "github.com/voidcontests/api/internal/storage/repository" + "github.com/voidcontests/api/pkg/ton" + "github.com/xssnick/tonutils-go/address" + "github.com/xssnick/tonutils-go/tlb" +) + +type ContestService struct { + repo *repository.Repository + ton *ton.Client + cipher crypto.Cipher +} + +func NewContestService(repo *repository.Repository, tc *ton.Client, cipher crypto.Cipher) *ContestService { + return &ContestService{ + repo: repo, + ton: tc, + cipher: cipher, + } +} + +// CreateContestParams contains parameters for creating a contest +type CreateContestParams struct { + UserID int + Title string + Description string + AwardType string + EntryPriceTonNanos uint64 + StartTime time.Time + EndTime time.Time + DurationMins int + MaxEntries int + AllowLateJoin bool + ProblemIDs []int +} + +func (s *ContestService) CreateContest(ctx context.Context, params CreateContestParams) (int, error) { + op := "service.ContestService.CreateContest" + + now := time.Now() + if params.StartTime.Before(now) { + return 0, ErrInvalidContestTiming + } + if !params.StartTime.Before(params.EndTime) { + return 0, ErrInvalidContestTiming + } + if params.DurationMins < 0 { + return 0, ErrInvalidContestTiming + } + + contestLengthMins := int(params.EndTime.Sub(params.StartTime).Minutes()) + if params.DurationMins > 0 && params.DurationMins > contestLengthMins { + return 0, ErrInvalidContestTiming + } + + userRole, err := s.repo.User.GetRole(ctx, params.UserID) + if err != nil { + return 0, fmt.Errorf("%s: failed to get user role: %w", op, err) + } + + if userRole.Name == models.RoleBanned { + return 0, ErrUserBanned + } + + if userRole.Name == models.RoleLimited { + contestsCount, err := s.repo.User.GetCreatedContestsCount(ctx, params.UserID) + if err != nil { + return 0, fmt.Errorf("%s: failed to get created contests count: %w", op, err) + } + + if contestsCount >= userRole.CreatedContestsLimit { + return 0, ErrContestsLimitExceeded + } + } + + const charcodes = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + if len(params.ProblemIDs) > len(charcodes) { + return 0, fmt.Errorf("too many problems: got %d, max %d", len(params.ProblemIDs), len(charcodes)) + } + + problems := make([]models.ProblemCharcode, len(params.ProblemIDs)) + for i, problemID := range params.ProblemIDs { + problems[i] = models.ProblemCharcode{ + ProblemID: problemID, + Charcode: string(charcodes[i]), + } + } + + // NOTE: if award type is not `paid_entry` or `sponsored` - use `no_prize` by default + var contestID int + switch params.AwardType { + case award.Sponsored, award.Pool: + w, err := s.ton.CreateWallet() + if err != nil { + return 0, err + } + + address := w.Address().String() + mnemonic := strings.Join(w.Mnemonic, " ") + + // Encrypt the mnemonic before storing + encryptedMnemonic, err := s.cipher.Encrypt(mnemonic) + if err != nil { + return 0, fmt.Errorf("%s: failed to encrypt mnemonic: %w", op, err) + } + + err = s.repo.TxManager.WithinTransaction(ctx, func(ctx context.Context, tx pgx.Tx) error { + repo := repository.NewTxRepository(tx) + + walletID, err := repo.Wallet.Create(ctx, address, encryptedMnemonic) + if err != nil { + return fmt.Errorf("create wallet: %w", err) + } + + contestID, err = repo.Contest.Create( + ctx, + params.UserID, + params.Title, + params.Description, + params.AwardType, + params.EntryPriceTonNanos, + params.StartTime, + params.EndTime, + params.DurationMins, + params.MaxEntries, + params.AllowLateJoin, + problems, + &walletID, + ) + if err != nil { + return fmt.Errorf("create contest: %w", err) + } + + return nil + }) + if err != nil { + return 0, fmt.Errorf("%s: failed to create contest: %w", op, err) + } + case award.No: + contestID, err = s.repo.Contest.Create( + ctx, + params.UserID, + params.Title, + params.Description, + award.No, + 0, + params.StartTime, + params.EndTime, + params.DurationMins, + params.MaxEntries, + params.AllowLateJoin, + problems, + nil, + ) + if err != nil { + return 0, fmt.Errorf("%s: failed to create contest: %w", op, err) + } + default: + return 0, ErrUnknownAwardType + } + + return contestID, nil +} + +type ContestDetails struct { + Contest models.Contest + IsRegistrationOpen bool + Problems []models.Problem + IsParticipant bool + ProblemStatuses map[int]string + WalletAddress string + PrizeNanosTON uint64 + EntryDetails *EntryDetails + DistributionPayment *models.Payment +} + +func (s *ContestService) GetContestByID(ctx context.Context, contestID int, userID int, authenticated bool) (*ContestDetails, error) { + op := "service.ContestService.GetContestByID" + + contest, err := s.repo.Contest.GetByID(ctx, contestID) + if errors.Is(err, pgx.ErrNoRows) { + return nil, ErrContestNotFound + } + if err != nil { + return nil, fmt.Errorf("%s: failed to get contest: %w", op, err) + } + + problems, err := s.repo.Contest.GetProblemset(ctx, contestID) + if err != nil { + return nil, fmt.Errorf("%s: failed to get problemset: %w", op, err) + } + + now := time.Now() + + // Registration is closed if: + // 1. Contest has ended + // 2. Contest has started and late join is not allowed + isRegistrationOpen := true + if contest.EndTime.Before(now) || (contest.StartTime.Before(now) && !contest.AllowLateJoin) { + isRegistrationOpen = false + } + details := &ContestDetails{ + IsRegistrationOpen: isRegistrationOpen, + Contest: contest, + Problems: problems, + } + + if contest.WalletID != nil && (contest.AwardType == award.Pool || contest.AwardType == award.Sponsored) { + wallet, err := s.repo.Contest.GetWallet(ctx, *contest.WalletID) + if err != nil { + return nil, fmt.Errorf("%s: failed to get wallet: %w", op, err) + } + + details.WalletAddress = wallet.Address + + addr, err := address.ParseAddr(wallet.Address) + if err != nil { + return nil, fmt.Errorf("%s: failed to parse wallet address: %w", op, err) + } + + details.PrizeNanosTON, err = s.ton.GetBalanceCached(ctx, addr) + if err != nil { + // TODO: maybe on this error, just return balance = 0 (?) + return nil, fmt.Errorf("%s: failed to get wallet balance: %w", op, err) + } + } + + if contest.DistributionPaymentID != nil { + payment, err := s.repo.Payment.GetByID(ctx, *contest.DistributionPaymentID) + if err != nil { + return nil, fmt.Errorf("%s: failed to get distribution payment: %w", op, err) + } + details.DistributionPayment = &payment + } + + if !authenticated { + return details, nil + } + + entry, err := s.repo.Entry.Get(ctx, contestID, userID) + if errors.Is(err, pgx.ErrNoRows) { + return details, nil + } + if err != nil { + return nil, fmt.Errorf("%s: failed to get entry: %w", op, err) + } + + details.IsParticipant = true + + statuses, err := s.repo.Submission.GetProblemStatuses(ctx, entry.ID) + if err != nil { + return nil, fmt.Errorf("%s: failed to get problem statuses: %w", op, err) + } + details.ProblemStatuses = statuses + + entryDetails, err := s.getEntryDetails(ctx, entry, contest) + if err != nil { + return nil, fmt.Errorf("%s: failed to get entry details: %w", op, err) + } + details.EntryDetails = &entryDetails + + if details.EntryDetails != nil { + _, deadline := CalculateSubmissionWindow(contest, entry) + if contest.StartTime.Before(now) { + details.EntryDetails.SubmissionDeadline = &deadline + } + } + + return details, nil +} + +type ListContestsResult struct { + Contests []models.Contest + Total int +} + +func (s *ContestService) ListCreatedContests(ctx context.Context, creatorID int, limit, offset int) (*ListContestsResult, error) { + op := "service.ContestService.ListCreatedContests" + + contests, total, err := s.repo.Contest.GetWithCreatorID(ctx, creatorID, limit, offset) + if err != nil { + return nil, fmt.Errorf("%s: failed to get created contests: %w", op, err) + } + + return &ListContestsResult{ + Contests: contests, + Total: total, + }, nil +} + +func (s *ContestService) ListAllContests(ctx context.Context, limit, offset int, filters models.ContestFilters) (*ListContestsResult, error) { + op := "service.ContestService.ListAllContests" + + contests, total, err := s.repo.Contest.ListAll(ctx, limit, offset, filters) + if err != nil { + return nil, fmt.Errorf("%s: failed to list all contests: %w", op, err) + } + + return &ListContestsResult{ + Contests: contests, + Total: total, + }, nil +} + +type ScoresResult struct { + Scores []models.ScoresEntry + Total int +} + +func (s *ContestService) GetScores(ctx context.Context, contestID int, limit, offset int) (*ScoresResult, error) { + op := "service.ContestService.GetScores" + + _, err := s.repo.Contest.GetByID(ctx, contestID) + if errors.Is(err, pgx.ErrNoRows) { + return nil, ErrContestNotFound + } + if err != nil { + return nil, fmt.Errorf("%s: failed to get contest: %w", op, err) + } + + scores, total, err := s.repo.Contest.GetScores(ctx, contestID, limit, offset) + if err != nil { + return nil, fmt.Errorf("%s: failed to get scores: %w", op, err) + } + + return &ScoresResult{ + Scores: scores, + Total: total, + }, nil +} + +type EntryDetails struct { + Entry models.Entry + IsAdmitted bool + SubmissionDeadline *time.Time + Message string + Payment *models.Payment +} + +func (s *ContestService) getEntryDetails(ctx context.Context, entry models.Entry, contest models.Contest) (EntryDetails, error) { + op := "service.ContestService.getEntryDetails" + + if entry.PaymentID != nil { + payment, err := s.repo.Payment.GetByID(ctx, *entry.PaymentID) + if err != nil { + return EntryDetails{}, fmt.Errorf("%s: failed to get payment: %w", op, err) + } + + return EntryDetails{ + Entry: entry, + IsAdmitted: true, + Payment: &payment, + }, nil + } + + if contest.AwardType != award.Pool { + return EntryDetails{ + Entry: entry, + IsAdmitted: true, + }, nil + } + + user, err := s.repo.User.GetByID(ctx, entry.UserID) + if err != nil { + return EntryDetails{}, fmt.Errorf("%s: failed to get user: %w", op, err) + } + + if user.Address == "" { + return EntryDetails{ + Entry: entry, + IsAdmitted: false, + Message: "address connected required to user account, to check payment", + }, nil + } + + from, err := address.ParseAddr(user.Address) + if err != nil { + return EntryDetails{}, fmt.Errorf("%s: failed to parse user address: %w", op, err) + } + + if contest.WalletID == nil { + return EntryDetails{}, errors.New("prized contest has no wallet") + } + + wallet, err := s.repo.Contest.GetWallet(ctx, *contest.WalletID) + if err != nil { + return EntryDetails{}, fmt.Errorf("%s: failed to get wallet: %w", op, err) + } + + to, err := address.ParseAddr(wallet.Address) + if err != nil { + return EntryDetails{}, fmt.Errorf("%s: failed to parse user address: %w", op, err) + } + + amount := tlb.FromNanoTONU(contest.EntryPriceTonNanos) + transaction, exists := s.ton.LookupTx(ctx, from, to, amount) + if !exists { + return EntryDetails{ + Entry: entry, + IsAdmitted: false, + Message: "payment required to participate in this contest", + }, nil + } + + var payment models.Payment + err = s.repo.TxManager.WithinTransaction(ctx, func(ctx context.Context, tx pgx.Tx) error { + repo := repository.NewTxRepository(tx) + + pid, err := repo.Payment.Create(ctx, transaction, s.ton.GetAddress(from), wallet.Address, contest.EntryPriceTonNanos, true) + if err != nil { + return fmt.Errorf("failed to create payment: %w", err) + } + + err = repo.Entry.SetPaymentID(ctx, entry.ID, pid) + if err != nil { + return fmt.Errorf("failed to set payment ID for entry: %w", err) + } + entry.PaymentID = &pid + + payment, err = repo.Payment.GetByID(ctx, pid) + if err != nil { + return fmt.Errorf("failed to get payment by ID: %w", err) + } + + return nil + }) + + if err != nil { + return EntryDetails{}, fmt.Errorf("%s: %w", op, err) + } + + return EntryDetails{ + Entry: entry, + IsAdmitted: true, + Payment: &payment, + }, nil +} + +func (s *ContestService) CreateEntry(ctx context.Context, contestID int, userID int) (int, error) { + op := "service.ContestService.CreateEntry" + + contest, err := s.repo.Contest.GetByID(ctx, contestID) + if errors.Is(err, pgx.ErrNoRows) { + return 0, ErrContestNotFound + } + if err != nil { + return 0, fmt.Errorf("%s: failed to get contest: %w", op, err) + } + + if contest.CreatorID == userID { + return 0, ErrCannotJoinOwnContest + } + + now := time.Now() + if contest.EndTime.Before(now) || (contest.StartTime.Before(now) && !contest.AllowLateJoin) { + return 0, ErrApplicationTimeOver + } + + var entryID int + err = s.repo.TxManager.WithinTransaction(ctx, func(ctx context.Context, tx pgx.Tx) error { + repo := repository.NewTxRepository(tx) + + _, err := repo.Entry.Get(ctx, contestID, userID) + if err == nil { + return ErrEntryAlreadyExists + } + if !errors.Is(err, pgx.ErrNoRows) { + return fmt.Errorf("failed to check existing entry: %w", err) + } + + entriesCount, err := repo.Contest.GetEntriesCount(ctx, contestID) + if err != nil { + return fmt.Errorf("failed to get entries count: %w", err) + } + + if contest.MaxEntries != 0 && entriesCount >= contest.MaxEntries { + return ErrMaxSlotsReached + } + + entryID, err = repo.Entry.Create(ctx, contestID, userID) + if err != nil { + return fmt.Errorf("failed to create entry: %w", err) + } + + return nil + }) + + if err != nil { + return 0, fmt.Errorf("%s: %w", op, err) + } + + return entryID, nil +} diff --git a/internal/app/service/errors.go b/internal/app/service/errors.go new file mode 100644 index 0000000..9685f0f --- /dev/null +++ b/internal/app/service/errors.go @@ -0,0 +1,46 @@ +package service + +import "errors" + +var ( + // account + ErrUserAlreadyExists = errors.New("user already exists") + ErrInvalidCredentials = errors.New("invalid credentials") + ErrUserNotFound = errors.New("user not found") + ErrTokenGeneration = errors.New("failed to generate token") + ErrInvalidToken = errors.New("invalid or expired token") + + // contest + ErrUnknownAwardType = errors.New("unknown award type") + ErrInvalidContestTiming = errors.New("invalid contest timing") + + // entry + ErrContestFinished = errors.New("contest not found") + ErrContestNotFound = errors.New("contest not found") + ErrCannotJoinOwnContest = errors.New("cannot join your own contest") + ErrMaxSlotsReached = errors.New("max slots limit reached") + ErrApplicationTimeOver = errors.New("application time is over") + ErrEntryAlreadyExists = errors.New("user already has entry for this contest") + ErrEntryNotFound = errors.New("entry not found") + ErrEntryNotPaid = errors.New("entry not paid") + + // problem + ErrUserBanned = errors.New("you are banned from creating problems") + ErrProblemsLimitExceeded = errors.New("problems limit exceeded") + ErrContestsLimitExceeded = errors.New("contests limit exceeded") + ErrInvalidTimeLimit = errors.New("time_limit_ms must be between 500 and 10000") + ErrInvalidMemoryLimit = errors.New("memory_limit_mb must be between 16 and 512") + ErrContestNotStarted = errors.New("contest not started yet") + ErrNotProblemWriter = errors.New("you are not the problem writer") + + // submissions + ErrProblemNotFound = errors.New("problem not found") + ErrNoEntryForContest = errors.New("no entry for contest") + ErrSubmissionWindowClosed = errors.New("submission window is currently closed") + ErrSubmissionNotFound = errors.New("submission not found") + ErrInvalidCharcode = errors.New("problem's charcode couldn't be longer than 2 characters") + ErrUnauthorizedAccess = errors.New("unauthorized access to this resource") + + // tonproof + ErrTonProofFailed = errors.New("tonproof verification failed") +) diff --git a/internal/app/service/problem.go b/internal/app/service/problem.go new file mode 100644 index 0000000..af6d4a1 --- /dev/null +++ b/internal/app/service/problem.go @@ -0,0 +1,242 @@ +package service + +import ( + "context" + "errors" + "fmt" + "strings" + "time" + + "github.com/jackc/pgx/v5" + "github.com/voidcontests/api/internal/storage/models" + "github.com/voidcontests/api/internal/storage/models/award" + "github.com/voidcontests/api/internal/storage/repository" +) + +type ProblemService struct { + repo *repository.Repository +} + +func NewProblemService(repo *repository.Repository) *ProblemService { + return &ProblemService{ + repo: repo, + } +} + +type CreateProblemParams struct { + UserID int + Title string + Statement string + Difficulty string + TimeLimitMS int + MemoryLimitMB int + Checker string + TestCases []models.TestCaseDTO +} + +func (s *ProblemService) CreateProblem(ctx context.Context, params CreateProblemParams) (int, error) { + op := "service.ProblemService.CreateProblem" + + userRole, err := s.repo.User.GetRole(ctx, params.UserID) + if err != nil { + return 0, fmt.Errorf("%s: failed to get user role: %w", op, err) + } + + if userRole.Name == models.RoleBanned { + return 0, ErrUserBanned + } + + if userRole.Name == models.RoleLimited { + problemsCount, err := s.repo.User.GetCreatedProblemsCount(ctx, params.UserID) + if err != nil { + return 0, fmt.Errorf("%s: failed to get created problems count: %w", op, err) + } + + if problemsCount >= userRole.CreatedProblemsLimit { + return 0, ErrProblemsLimitExceeded + } + } + + if params.TimeLimitMS < 500 || params.TimeLimitMS > 10000 { + return 0, ErrInvalidTimeLimit + } + + if params.MemoryLimitMB < 16 || params.MemoryLimitMB > 512 { + return 0, ErrInvalidMemoryLimit + } + + examplesCount := 0 + for i := range params.TestCases { + if params.TestCases[i].IsExample { + examplesCount++ + } + + if examplesCount > 3 && params.TestCases[i].IsExample { + params.TestCases[i].IsExample = false + } + } + + checker := params.Checker + if checker == "" { + checker = "tokens" + } + + const MAX_TEST_CASES = 100 + if len(params.TestCases) > MAX_TEST_CASES { + return 0, fmt.Errorf("too many test cases: got %d, max %d", len(params.TestCases), MAX_TEST_CASES) + } + + for i, tc := range params.TestCases { + if tc.Input == "" && tc.Output == "" { + return 0, fmt.Errorf("test case %d: both input and output are empty", i) + } + } + + var problemID int + err = s.repo.TxManager.WithinTransaction(ctx, func(ctx context.Context, tx pgx.Tx) error { + repo := repository.NewTxRepository(tx) + + problemID, err = repo.Problem.Create(ctx, params.UserID, params.Title, params.Statement, params.Difficulty, params.TimeLimitMS, params.MemoryLimitMB, checker) + if err != nil { + return fmt.Errorf("create problem: %w", err) + } + + err = repo.Problem.AssociateTestCases(ctx, problemID, params.TestCases) + if err != nil { + return fmt.Errorf("associate test cases: %w", err) + } + + return nil + }) + if err != nil { + return 0, fmt.Errorf("%s: failed to create problem: %w", op, err) + } + + return problemID, nil +} + +type ListProblemsResult struct { + Problems []models.Problem + Total int +} + +func (s *ProblemService) GetCreatedProblems(ctx context.Context, writerID int, limit, offset int) (*ListProblemsResult, error) { + op := "service.ProblemService.GetCreatedProblems" + + problems, total, err := s.repo.Problem.GetWithWriterID(ctx, writerID, limit, offset) + if err != nil { + return nil, fmt.Errorf("%s: failed to get created problems: %w", op, err) + } + + return &ListProblemsResult{ + Problems: problems, + Total: total, + }, nil +} + +type ContestProblemDetails struct { + Problem models.Problem + Examples []models.TestCase + Status string + SubmissionWindow SubmissionWindow +} + +type SubmissionWindow struct { + Earliest time.Time + Deadline time.Time +} + +func (s *ProblemService) GetContestProblem(ctx context.Context, contestID int, userID int, charcode string) (*ContestProblemDetails, error) { + op := "service.ProblemService.GetContestProblem" + + if len(charcode) > 2 { + return nil, ErrInvalidCharcode + } + charcode = strings.ToUpper(charcode) + + contest, err := s.repo.Contest.GetByID(ctx, contestID) + if errors.Is(err, pgx.ErrNoRows) { + return nil, ErrContestNotFound + } + if err != nil { + return nil, fmt.Errorf("%s: failed to get contest: %w", op, err) + } + + now := time.Now() + if contest.StartTime.After(now) { + return nil, ErrContestNotStarted + } + + entry, err := s.repo.Entry.Get(ctx, contestID, userID) + if errors.Is(err, pgx.ErrNoRows) { + return nil, ErrNoEntryForContest + } + if err != nil { + return nil, fmt.Errorf("%s: failed to get entry: %w", op, err) + } + + if contest.AwardType == award.Pool && entry.PaymentID == nil { + return nil, ErrEntryNotPaid + } + + problem, err := s.repo.Problem.Get(ctx, contestID, charcode) + if errors.Is(err, pgx.ErrNoRows) { + return nil, ErrProblemNotFound + } + if err != nil { + return nil, fmt.Errorf("%s: failed to get problem: %w", op, err) + } + + examples, err := s.repo.Problem.GetExampleCases(ctx, problem.ID) + if err != nil { + return nil, fmt.Errorf("%s: failed to get example cases: %w", op, err) + } + + status, err := s.repo.Submission.GetProblemStatus(ctx, entry.ID, problem.ID) + if err != nil { + return nil, fmt.Errorf("%s: failed to get problem status: %w", op, err) + } + + earliest, deadline := CalculateSubmissionWindow(contest, entry) + + return &ContestProblemDetails{ + Problem: problem, + Examples: examples, + Status: status, + SubmissionWindow: SubmissionWindow{ + Earliest: earliest, + Deadline: deadline, + }, + }, nil +} + +type ProblemDetails struct { + Problem models.Problem + Examples []models.TestCase +} + +func (s *ProblemService) GetProblemByID(ctx context.Context, problemID int, userID int) (*ProblemDetails, error) { + op := "service.ProblemService.GetProblemByID" + + problem, err := s.repo.Problem.GetByID(ctx, problemID) + if errors.Is(err, pgx.ErrNoRows) { + return nil, ErrProblemNotFound + } + if err != nil { + return nil, fmt.Errorf("%s: failed to get problem: %w", op, err) + } + + if problem.WriterID != userID { + return nil, ErrNotProblemWriter + } + + examples, err := s.repo.Problem.GetExampleCases(ctx, problem.ID) + if err != nil { + return nil, fmt.Errorf("%s: failed to get example cases: %w", op, err) + } + + return &ProblemDetails{ + Problem: problem, + Examples: examples, + }, nil +} diff --git a/internal/app/service/service.go b/internal/app/service/service.go new file mode 100644 index 0000000..4fadb32 --- /dev/null +++ b/internal/app/service/service.go @@ -0,0 +1,27 @@ +package service + +import ( + "github.com/voidcontests/api/internal/config" + "github.com/voidcontests/api/internal/lib/crypto" + "github.com/voidcontests/api/internal/storage/broker" + "github.com/voidcontests/api/internal/storage/repository" + "github.com/voidcontests/api/pkg/ton" +) + +type Service struct { + Account *AccountService + Submission *SubmissionService + Problem *ProblemService + Contest *ContestService + TonProof *TonProofService +} + +func New(cfg *config.Security, repo *repository.Repository, broker broker.Broker, tc *ton.Client, cipher crypto.Cipher) *Service { + return &Service{ + Account: NewAccountService(cfg, repo), + Submission: NewSubmissionService(repo, broker), + Problem: NewProblemService(repo), + Contest: NewContestService(repo, tc, cipher), + TonProof: NewTonProofService(repo, tc), + } +} diff --git a/internal/app/service/submission.go b/internal/app/service/submission.go new file mode 100644 index 0000000..33c738a --- /dev/null +++ b/internal/app/service/submission.go @@ -0,0 +1,190 @@ +package service + +import ( + "context" + "errors" + "fmt" + "strings" + "time" + + "github.com/jackc/pgx/v5" + "github.com/voidcontests/api/internal/storage/broker" + "github.com/voidcontests/api/internal/storage/models" + "github.com/voidcontests/api/internal/storage/models/status" + "github.com/voidcontests/api/internal/storage/repository" +) + +type SubmissionService struct { + repo *repository.Repository + broker broker.Broker +} + +func NewSubmissionService(repo *repository.Repository, broker broker.Broker) *SubmissionService { + return &SubmissionService{ + repo: repo, + broker: broker, + } +} + +type CreateSubmissionParams struct { + ContestID int + UserID int + Charcode string + Code string + Language string +} + +type CreateSubmissionResult struct { + Submission models.Submission +} + +func (s *SubmissionService) CreateSubmission(ctx context.Context, params CreateSubmissionParams) (*CreateSubmissionResult, error) { + op := "service.SubmissionService.CreateSubmission" + + if len(params.Charcode) > 2 { + return nil, ErrInvalidCharcode + } + charcode := strings.ToUpper(params.Charcode) + + contest, err := s.repo.Contest.GetByID(ctx, params.ContestID) + if errors.Is(err, pgx.ErrNoRows) { + return nil, ErrContestNotFound + } + if err != nil { + return nil, fmt.Errorf("%s: failed to get contest: %w", op, err) + } + + entry, err := s.repo.Entry.Get(ctx, params.ContestID, params.UserID) + if errors.Is(err, pgx.ErrNoRows) { + return nil, ErrNoEntryForContest + } + if err != nil { + return nil, fmt.Errorf("%s: failed to get entry: %w", op, err) + } + + now := time.Now() + earliest, deadline := CalculateSubmissionWindow(contest, entry) + if earliest.After(now) || deadline.Before(now) { + return nil, ErrSubmissionWindowClosed + } + + problem, err := s.repo.Problem.Get(ctx, params.ContestID, charcode) + if errors.Is(err, pgx.ErrNoRows) { + return nil, ErrProblemNotFound + } + if err != nil { + return nil, fmt.Errorf("%s: failed to get problem: %w", op, err) + } + + submission, err := s.repo.Submission.Create(ctx, entry.ID, problem.ID, params.Code, params.Language) + if err != nil { + return nil, fmt.Errorf("%s: failed to create submission: %w", op, err) + } + + if err := s.broker.PublishSubmission(ctx, submission); err != nil { + return nil, fmt.Errorf("%s: failed to publish submission: %w", op, err) + } + + return &CreateSubmissionResult{ + Submission: submission, + }, nil +} + +type SubmissionDetails struct { + Submission models.Submission + TestingReport *models.TestingReport + FailedTest *models.TestCase +} + +func (s *SubmissionService) GetSubmissionByID(ctx context.Context, submissionID int, userID int) (*SubmissionDetails, error) { + op := "service.SubmissionService.GetSubmissionByID" + + submission, err := s.repo.Submission.GetByID(ctx, submissionID) + if errors.Is(err, pgx.ErrNoRows) { + return nil, ErrSubmissionNotFound + } + if err != nil { + return nil, fmt.Errorf("%s: failed to get submission: %w", op, err) + } + + if submission.UserID != userID { + return nil, ErrUnauthorizedAccess + } + + details := &SubmissionDetails{ + Submission: submission, + } + + if submission.Status != status.Success { + return details, nil + } + + testingReport, err := s.repo.Submission.GetTestingReport(ctx, submission.ID) + if err != nil { + return nil, fmt.Errorf("%s: failed to get testing report: %w", op, err) + } + details.TestingReport = &testingReport + + if testingReport.FirstFailedTestID != nil { + failedTest, err := s.repo.Problem.GetTestCaseByID(ctx, *testingReport.FirstFailedTestID) + if err != nil { + return nil, fmt.Errorf("%s: failed to get failed test case: %w", op, err) + } + details.FailedTest = &failedTest + } + + return details, nil +} + +type ListSubmissionsResult struct { + Submissions []models.Submission + Total int +} + +func (s *SubmissionService) ListSubmissions(ctx context.Context, contestID int, userID int, charcode string, limit, offset int) (*ListSubmissionsResult, error) { + op := "service.SubmissionService.ListSubmissions" + + if len(charcode) > 2 { + return nil, ErrInvalidCharcode + } + charcode = strings.ToUpper(charcode) + + entry, err := s.repo.Entry.Get(ctx, contestID, userID) + if errors.Is(err, pgx.ErrNoRows) { + return nil, ErrNoEntryForContest + } + if err != nil { + return nil, fmt.Errorf("%s: failed to get entry: %w", op, err) + } + + submissions, total, err := s.repo.Submission.ListByProblem(ctx, entry.ID, charcode, limit, offset) + if err != nil { + return nil, fmt.Errorf("%s: failed to list submissions: %w", op, err) + } + + return &ListSubmissionsResult{ + Submissions: submissions, + Total: total, + }, nil +} + +func CalculateSubmissionWindow(contest models.Contest, entry models.Entry) (earliest time.Time, deadline time.Time) { + if contest.DurationMins == 0 { + return contest.StartTime, contest.EndTime + } + + earliest = entry.CreatedAt + if contest.StartTime.After(earliest) { + earliest = contest.StartTime + } + + personalDeadline := earliest.Add(time.Duration(contest.DurationMins) * time.Minute) + + if personalDeadline.Before(contest.EndTime) { + deadline = personalDeadline + } else { + deadline = contest.EndTime + } + + return earliest, deadline +} diff --git a/internal/app/service/tonproof.go b/internal/app/service/tonproof.go new file mode 100644 index 0000000..5eaf075 --- /dev/null +++ b/internal/app/service/tonproof.go @@ -0,0 +1,98 @@ +package service + +import ( + "context" + "fmt" + "log/slog" + + "github.com/tonkeeper/tongo/tonconnect" + "github.com/voidcontests/api/internal/app/handler/dto/request" + "github.com/voidcontests/api/internal/lib/logger/sl" + "github.com/voidcontests/api/internal/storage/models" + "github.com/voidcontests/api/internal/storage/repository" + "github.com/voidcontests/api/pkg/ton" + "github.com/xssnick/tonutils-go/address" +) + +// TODO: Wrap errors + +type TonProofService struct { + repo *repository.Repository + ton *ton.Client + testnet bool +} + +func NewTonProofService(repo *repository.Repository, tc *ton.Client) *TonProofService { + return &TonProofService{ + repo: repo, + ton: tc, + testnet: tc.IsTestnet(), + } +} + +func (s *TonProofService) GeneratePayload() (string, error) { + op := "service.TonProofService.GeneratePayload" + + payload, err := s.ton.TonConnect.GeneratePayload() + if err != nil { + slog.Error("tonproof: failed to generate payload", slog.String("op", op), slog.String("error", err.Error())) + return "", fmt.Errorf("%s: %w", op, err) + } + + slog.Info("tonproof: payload generated", slog.String("payload", payload)) + return payload, nil +} + +func (s *TonProofService) VerifyProofAndSetAddress(ctx context.Context, userID int, tp request.TonProof) error { + op := "service.TonProofService.VerifyProofAndSetAddress" + + expectedNetwork := ton.MainnetID + if s.testnet { + expectedNetwork = ton.TestnetID + } + if tp.Network != expectedNetwork { + return fmt.Errorf("%s: network mismatch", op) + } + + proof := tonconnect.Proof{ + Address: tp.Address, + Proof: tonconnect.ProofData{ + Timestamp: tp.Proof.Timestamp, + Domain: tp.Proof.Domain.Value, + Signature: tp.Proof.Signature, + Payload: tp.Proof.Payload, + StateInit: tp.Proof.StateInit, + }, + } + + allowAnyDomain := func(addr string) (bool, error) { + return true, nil + } + + verified, _, err := s.ton.TonConnect.CheckProof(ctx, &proof, s.ton.TonConnect.CheckPayload, allowAnyDomain) + if err != nil { + slog.Error("tonproof: proof verification failed", slog.String("op", op), sl.Err(err)) + return ErrTonProofFailed + } + if !verified { + slog.Warn("tonproof: proof not verified", slog.String("op", op)) + return ErrTonProofFailed + } + + addr, err := address.ParseRawAddr(tp.Address) + if err != nil { + slog.Warn("tonproof: failed to parse raw address", sl.Err(err)) + return fmt.Errorf("%s: failed to parse address: %w", op, err) + } + + addrStr := addr.Testnet(s.testnet).String() + + _, err = s.repo.User.UpdateUser(ctx, userID, models.UpdateUserParams{ + Address: &addrStr, + }) + if err != nil { + return fmt.Errorf("%s: failed to update user: %w", op, err) + } + + return nil +} diff --git a/internal/config/config.go b/internal/config/config.go index e6eb6f0..c53f5cc 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -22,6 +22,7 @@ type Config struct { Security Security `yaml:"security" env-required:"true"` Postgres Postgres `yaml:"postgres" env-required:"true"` Redis Redis `yaml:"redis" env-required:"true"` + Ton Ton `yaml:"ton" env-required:"true"` } type Server struct { @@ -31,8 +32,9 @@ type Server struct { } type Security struct { - SignatureKey string `yaml:"signature_key" env-required:"true"` - Salt string `yaml:"salt" env-required:"true"` + SignatureKey string `yaml:"signature_key" env-required:"true"` + Salt string `yaml:"salt" env-required:"true"` + WalletEncryptKey string `yaml:"wallet_encrypt_key" env-required:"true"` } type Postgres struct { @@ -51,6 +53,18 @@ type Redis struct { Db int `yaml:"db"` } +type Ton struct { + IsTestnet bool `yaml:"is_testnet"` + ConfigURL string `yaml:"config_url"` + Proof TonProof `yaml:"proof"` +} + +type TonProof struct { + PayloadSignatureKey string `yaml:"payload_signature_key" env-required:"true"` + PayloadLifetime time.Duration `yaml:"payload_lifetime" env-default:"600s"` + ProofLifetime time.Duration `yaml:"proof_lifetime" env-default:"600s"` +} + // MustLoad loads config to a new Config instance and return it func MustLoad() *Config { _ = godotenv.Load() diff --git a/internal/jwt/jwt.go b/internal/jwt/jwt.go index 1529645..2b11235 100644 --- a/internal/jwt/jwt.go +++ b/internal/jwt/jwt.go @@ -9,13 +9,14 @@ import ( type CustomClaims struct { jwt.RegisteredClaims - UserID int32 `json:"id"` + UserID int `json:"id"` } -func GenerateToken(id int32, secret string) (string, error) { +func GenerateToken(id int, secret string) (string, error) { + expiration := 7 * 24 * time.Hour // 7 days claims := &CustomClaims{ jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().AddDate(100, 0, 0)), + ExpiresAt: jwt.NewNumericDate(time.Now().Add(expiration)), }, id, } @@ -30,7 +31,7 @@ func GenerateToken(id int32, secret string) (string, error) { return signedToken, nil } -func Parse(token, secret string) (id int32, err error) { +func Parse(token, secret string) (id int, err error) { jsonwebtoken, err := jwt.Parse(token, func(token *jwt.Token) (interface{}, error) { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) diff --git a/internal/lib/crypto/crypto.go b/internal/lib/crypto/crypto.go new file mode 100644 index 0000000..497ccf8 --- /dev/null +++ b/internal/lib/crypto/crypto.go @@ -0,0 +1,86 @@ +package crypto + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "fmt" + "io" +) + +type Cipher interface { + Encrypt(plaintext string) (string, error) + Decrypt(ciphertext string) (string, error) +} + +func NewCipher(key string) Cipher { + hash := sha256.Sum256([]byte(key)) + return &encryptor{ + key: hash[:], + } +} + +type encryptor struct { + key []byte +} + +func (e *encryptor) Encrypt(plaintext string) (string, error) { + if plaintext == "" { + return "", fmt.Errorf("plaintext cannot be empty") + } + + block, err := aes.NewCipher(e.key) + if err != nil { + return "", fmt.Errorf("failed to create cipher: %w", err) + } + + gcm, err := cipher.NewGCM(block) + if err != nil { + return "", fmt.Errorf("failed to create GCM: %w", err) + } + + nonce := make([]byte, gcm.NonceSize()) + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + return "", fmt.Errorf("failed to generate nonce: %w", err) + } + + ciphertext := gcm.Seal(nonce, nonce, []byte(plaintext), nil) + + return base64.StdEncoding.EncodeToString(ciphertext), nil +} + +func (e *encryptor) Decrypt(ciphertext string) (string, error) { + if ciphertext == "" { + return "", fmt.Errorf("ciphertext cannot be empty") + } + + data, err := base64.StdEncoding.DecodeString(ciphertext) + if err != nil { + return "", fmt.Errorf("failed to decode base64: %w", err) + } + + block, err := aes.NewCipher(e.key) + if err != nil { + return "", fmt.Errorf("failed to create cipher: %w", err) + } + + gcm, err := cipher.NewGCM(block) + if err != nil { + return "", fmt.Errorf("failed to create GCM: %w", err) + } + + nonceSize := gcm.NonceSize() + if len(data) < nonceSize { + return "", fmt.Errorf("ciphertext too short") + } + + nonce, encryptedData := data[:nonceSize], data[nonceSize:] + plaintext, err := gcm.Open(nil, nonce, encryptedData, nil) + if err != nil { + return "", fmt.Errorf("failed to decrypt: %w", err) + } + + return string(plaintext), nil +} diff --git a/internal/lib/crypto/crypto_test.go b/internal/lib/crypto/crypto_test.go new file mode 100644 index 0000000..18d638b --- /dev/null +++ b/internal/lib/crypto/crypto_test.go @@ -0,0 +1,150 @@ +package crypto + +import ( + "strings" + "testing" +) + +func TestEncryptDecrypt(t *testing.T) { + encryptor := NewCipher("test-encryption-key-12345") + + tests := []struct { + name string + plaintext string + }{ + { + name: "simple string", + plaintext: "hello world", + }, + { + name: "mnemonic phrase", + plaintext: "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about", + }, + { + name: "long string", + plaintext: strings.Repeat("a", 1000), + }, + { + name: "special characters", + plaintext: "!@#$%^&*()_+-={}[]|\\:\";<>?,./", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ciphertext, err := encryptor.Encrypt(tt.plaintext) + if err != nil { + t.Fatalf("Encrypt() error = %v", err) + } + + if ciphertext == "" { + t.Fatal("Encrypt() returned empty ciphertext") + } + + if ciphertext == tt.plaintext { + t.Fatal("Encrypt() returned plaintext unchanged") + } + + decrypted, err := encryptor.Decrypt(ciphertext) + if err != nil { + t.Fatalf("Decrypt() error = %v", err) + } + + if decrypted != tt.plaintext { + t.Errorf("Decrypt() = %v, want %v", decrypted, tt.plaintext) + } + }) + } +} + +func TestEncryptEmptyString(t *testing.T) { + encryptor := NewCipher("test-key") + + _, err := encryptor.Encrypt("") + if err == nil { + t.Error("Encrypt() should return error for empty string") + } +} + +func TestDecryptEmptyString(t *testing.T) { + encryptor := NewCipher("test-key") + + _, err := encryptor.Decrypt("") + if err == nil { + t.Error("Decrypt() should return error for empty string") + } +} + +func TestDecryptInvalidCiphertext(t *testing.T) { + encryptor := NewCipher("test-key") + + tests := []struct { + name string + ciphertext string + }{ + { + name: "invalid base64", + ciphertext: "not-valid-base64!@#", + }, + { + name: "too short", + ciphertext: "YWJj", + }, + { + name: "random data", + ciphertext: "SGVsbG8gV29ybGQ=", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := encryptor.Decrypt(tt.ciphertext) + if err == nil { + t.Error("Decrypt() should return error for invalid ciphertext") + } + }) + } +} + +func TestDifferentKeys(t *testing.T) { + encryptor1 := NewCipher("key1") + encryptor2 := NewCipher("key2") + + plaintext := "secret message" + + ciphertext, err := encryptor1.Encrypt(plaintext) + if err != nil { + t.Fatalf("Encrypt() error = %v", err) + } + + _, err = encryptor2.Decrypt(ciphertext) + if err == nil { + t.Error("Decrypt() should fail when using different key") + } +} + +func TestEncryptionUniqueness(t *testing.T) { + encryptor := NewCipher("test-key") + plaintext := "same message" + + ciphertext1, err := encryptor.Encrypt(plaintext) + if err != nil { + t.Fatalf("Encrypt() error = %v", err) + } + + ciphertext2, err := encryptor.Encrypt(plaintext) + if err != nil { + t.Fatalf("Encrypt() error = %v", err) + } + + if ciphertext1 == ciphertext2 { + t.Error("Encrypt() should produce different ciphertexts for same plaintext (different nonces)") + } + + decrypted1, _ := encryptor.Decrypt(ciphertext1) + decrypted2, _ := encryptor.Decrypt(ciphertext2) + + if decrypted1 != plaintext || decrypted2 != plaintext { + t.Error("Both ciphertexts should decrypt to original plaintext") + } +} diff --git a/internal/pkg/app/app.go b/internal/pkg/app/app.go index 97bd03d..db921d5 100644 --- a/internal/pkg/app/app.go +++ b/internal/pkg/app/app.go @@ -9,16 +9,21 @@ import ( "os" "os/signal" "syscall" + "time" "github.com/redis/go-redis/v9" + "github.com/voidcontests/api/internal/app/distributor" "github.com/voidcontests/api/internal/app/router" "github.com/voidcontests/api/internal/config" + "github.com/voidcontests/api/internal/lib/crypto" "github.com/voidcontests/api/internal/lib/logger/prettyslog" "github.com/voidcontests/api/internal/lib/logger/sl" broker "github.com/voidcontests/api/internal/storage/broker/redis" "github.com/voidcontests/api/internal/storage/repository" "github.com/voidcontests/api/internal/storage/repository/postgres" "github.com/voidcontests/api/internal/version" + "github.com/voidcontests/api/pkg/scheduler" + "github.com/voidcontests/api/pkg/ton" ) type App struct { @@ -74,7 +79,17 @@ func (a *App) Run() { repo := repository.New(db) brok := broker.New(rc) - r := router.New(a.config, repo, brok) + tonc, err := ton.NewClient(ctx, &a.config.Ton) + if err != nil { + slog.Error("ton: could not establish connection", sl.Err(err)) + return + } + + slog.Info("ton: ok", slog.Bool("is_testnet", a.config.Ton.IsTestnet)) + + cipher := crypto.NewCipher(a.config.Security.WalletEncryptKey) + + r := router.New(a.config, repo, brok, tonc, cipher) server := &http.Server{ Addr: a.config.Server.Address, @@ -96,6 +111,15 @@ func (a *App) Run() { slog.Info("api: started", slog.String("address", server.Addr)) + interval := 1 * time.Minute + task := distributor.New(repo, tonc, cipher) + scheduler := scheduler.New(interval, task) + + go func() { + scheduler.Start(ctx) + defer scheduler.Stop() + }() + quit := make(chan os.Signal, 1) signal.Notify(quit, syscall.SIGTERM, syscall.SIGINT) <-quit diff --git a/internal/storage/models/award/award.go b/internal/storage/models/award/award.go new file mode 100644 index 0000000..2862af0 --- /dev/null +++ b/internal/storage/models/award/award.go @@ -0,0 +1,7 @@ +package award + +const ( + No = "no" + Sponsored = "sponsored" + Pool = "pool" +) diff --git a/internal/storage/models/models.go b/internal/storage/models/models.go index ccefcbf..a745b04 100644 --- a/internal/storage/models/models.go +++ b/internal/storage/models/models.go @@ -10,55 +10,94 @@ const ( ) type User struct { - ID int32 `db:"id"` + ID int `db:"id"` Username string `db:"username"` PasswordHash string `db:"password_hash"` - RoleID int32 `db:"role_id"` + RoleID int `db:"role_id"` + Address string `db:"address"` CreatedAt time.Time `db:"created_at"` } +type UpdateUserParams struct { + Username *string + Address *string +} + type Role struct { - ID int32 `db:"id"` + ID int `db:"id"` Name string `db:"name"` - CreatedProblemsLimit int32 `db:"created_problems_limit"` - CreatedContestsLimit int32 `db:"created_contests_limit"` + CreatedProblemsLimit int `db:"created_problems_limit"` + CreatedContestsLimit int `db:"created_contests_limit"` IsDefault bool `db:"is_default"` CreatedAt time.Time `db:"created_at"` } type Contest struct { - ID int32 `db:"id"` - CreatorID int32 `db:"creator_id"` - CreatorUsername string `db:"creator_username"` - Title string `db:"title"` - Description string `db:"description"` - StartTime time.Time `db:"start_time"` - EndTime time.Time `db:"end_time"` - DurationMins int32 `db:"duration_mins"` - MaxEntries int32 `db:"max_entries"` - AllowLateJoin bool `db:"allow_late_join"` - Participants int32 `db:"participants"` - CreatedAt time.Time `db:"created_at"` + ID int `db:"id"` + CreatorID int `db:"creator_id"` + CreatorUsername string `db:"creator_username"` + CreatorAddress string `db:"creator_address"` + Title string `db:"title"` + Description string `db:"description"` + AwardType string `db:"award_type"` + EntryPriceTonNanos uint64 `db:"entry_price_ton_nanos"` + StartTime time.Time `db:"start_time"` + EndTime time.Time `db:"end_time"` + DurationMins int `db:"duration_mins"` + MaxEntries int `db:"max_entries"` + AllowLateJoin bool `db:"allow_late_join"` + ParticipantsCount int `db:"participants"` + WalletID *int `db:"wallet_id"` + DistributionPaymentID *int `db:"distribution_payment_id"` + CreatedAt time.Time `db:"created_at"` +} + +type ContestFilters struct { + CreatorID int + Title string +} + +type ProblemCharcode struct { + ProblemID int + Charcode string +} + +type Wallet struct { + ID int `db:"id"` + Address string `db:"address"` + MnemonicEncrypted string `db:"mnemonic_encrypted"` + CreatedAt time.Time `db:"created_at"` +} + +type Payment struct { + ID int `db:"id"` + TxHash string `db:"tx_hash"` + FromAddress string `db:"from_address"` + ToAddress string `db:"to_address"` + AmountTonNanos uint64 `db:"amount_ton_nanos"` + IsIncoming bool `db:"is_incoming"` + CreatedAt time.Time `db:"created_at"` } type Problem struct { - ID int32 `db:"id"` + ID int `db:"id"` Charcode string `db:"charcode"` - WriterID int32 `db:"writer_id"` + WriterID int `db:"writer_id"` WriterUsername string `db:"writer_username"` + WriterAddress string `db:"writer_address"` Title string `db:"title"` Statement string `db:"statement"` Difficulty string `db:"difficulty"` - TimeLimitMS int32 `db:"time_limit_ms"` - MemoryLimitMB int32 `db:"memory_limit_mb"` + TimeLimitMS int `db:"time_limit_ms"` + MemoryLimitMB int `db:"memory_limit_mb"` Checker string `db:"checker"` CreatedAt time.Time `db:"created_at"` } type TestCase struct { - ID int32 `db:"id"` - ProblemID int32 `db:"problem_id"` - Ordinal int32 `db:"ordinal"` + ID int `db:"id"` + ProblemID int `db:"problem_id"` + Ordinal int `db:"ordinal"` Input string `db:"input"` Output string `db:"output"` IsExample bool `db:"is_example"` @@ -71,16 +110,20 @@ type TestCaseDTO struct { } type Entry struct { - ID int32 `db:"id"` - ContestID int32 `db:"contest_id"` - UserID int32 `db:"user_id"` + ID int `db:"id"` + ContestID int `db:"contest_id"` + UserID int `db:"user_id"` + PaymentID *int `db:"payment_id"` CreatedAt time.Time `db:"created_at"` } type Submission struct { - ID int32 `db:"id"` - EntryID int32 `db:"entry_id"` - ProblemID int32 `db:"problem_id"` + ID int `db:"id"` + EntryID int `db:"entry_id"` + ProblemID int `db:"problem_id"` + ContestID int `db:"contest_id"` + UserID int `db:"user_id"` + Username string `db:"username"` Status string `db:"status"` Verdict string `db:"verdict"` Code string `db:"code"` @@ -89,25 +132,25 @@ type Submission struct { } type TestingReport struct { - ID int32 `db:"id"` - SubmissionID int32 `db:"submission_id"` - PassedTestsCount int32 `db:"passed_tests_count"` - TotalTestsCount int32 `db:"total_tests_count"` - FirstFailedTestID *int32 `db:"first_failed_test_id"` + ID int `db:"id"` + SubmissionID int `db:"submission_id"` + PassedTestsCount int `db:"passed_tests_count"` + TotalTestsCount int `db:"total_tests_count"` + FirstFailedTestID *int `db:"first_failed_test_id"` FirstFailedTestOutput *string `db:"first_failed_test_output"` Stderr string `db:"stderr"` CreatedAt time.Time `db:"created_at"` } -type LeaderboardEntry struct { - UserID int32 `db:"user_id" json:"user_id"` +type ScoresEntry struct { + UserID int `db:"user_id" json:"user_id"` Username string `db:"username" json:"username"` Points int `db:"points" json:"points"` } type FailedTest struct { - ID int32 `db:"id"` - SubmissionID int32 `db:"submission_id"` + ID int `db:"id"` + SubmissionID int `db:"submission_id"` Input string `db:"input"` ExpectedOutput string `db:"expected_output"` ActualOutput string `db:"actual_output"` diff --git a/internal/storage/repository/postgres/contest/contest.go b/internal/storage/repository/postgres/contest/contest.go index 329d58f..4ea1bbc 100644 --- a/internal/storage/repository/postgres/contest/contest.go +++ b/internal/storage/repository/postgres/contest/contest.go @@ -7,99 +7,92 @@ import ( "time" "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgxpool" "github.com/voidcontests/api/internal/storage/models" + "github.com/voidcontests/api/internal/storage/repository/postgres" ) const defaultLimit = 20 type Postgres struct { - pool *pgxpool.Pool + conn postgres.Transactor } -func New(pool *pgxpool.Pool) *Postgres { - return &Postgres{pool} +func New(txr postgres.Transactor) *Postgres { + return &Postgres{conn: txr} } -func (p *Postgres) Create(ctx context.Context, creatorID int32, title, description string, startTime, endTime time.Time, durationMins, maxEntries int32, allowLateJoin bool) (int32, error) { - var id int32 - query := `INSERT INTO contests (creator_id, title, description, start_time, end_time, duration_mins, max_entries, allow_late_join) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8) RETURNING id` - err := p.pool.QueryRow(ctx, query, creatorID, title, description, startTime, endTime, durationMins, maxEntries, allowLateJoin).Scan(&id) - return id, err -} +func (p *Postgres) Create(ctx context.Context, creatorID int, title, desc, awardType string, entryPriceTonNanos uint64, startTime, endTime time.Time, durationMins, maxEntries int, allowLateJoin bool, problems []models.ProblemCharcode, walletID *int) (int, error) { + var contestID int + var err error -func (p *Postgres) CreateWithProblemIDs(ctx context.Context, creatorID int32, title, desc string, startTime, endTime time.Time, durationMins, maxEntries int32, allowLateJoin bool, problemIDs []int32) (int32, error) { - charcodes := "ABCDEFGHIJKLMNOPQRSTUVWXYZ" - if len(problemIDs) > len(charcodes) { - return 0, fmt.Errorf("not enough charcodes for the number of problems") - } + err = p.conn.QueryRow(ctx, ` +INSERT INTO contests (creator_id, title, description, award_type, entry_price_ton_nanos, start_time, end_time, duration_mins, max_entries, allow_late_join, wallet_id) +VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) RETURNING id +`, creatorID, title, desc, awardType, entryPriceTonNanos, startTime, endTime, durationMins, maxEntries, allowLateJoin, walletID).Scan(&contestID) - tx, err := p.pool.Begin(ctx) - if err != nil { - return 0, fmt.Errorf("begin transaction: %w", err) - } - defer tx.Rollback(ctx) - - var contestID int32 - err = tx.QueryRow(ctx, ` - INSERT INTO contests - (creator_id, title, description, start_time, end_time, duration_mins, max_entries, allow_late_join) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8) - RETURNING id - `, creatorID, title, desc, startTime, endTime, durationMins, maxEntries, allowLateJoin).Scan(&contestID) if err != nil { return 0, fmt.Errorf("insert contest failed: %w", err) } - batch := &pgx.Batch{} - for i, pid := range problemIDs { - batch.Queue(` - INSERT INTO contest_problems (contest_id, problem_id, charcode) - VALUES ($1, $2, $3) - `, contestID, pid, string(charcodes[i])) - } + if len(problems) > 0 { + batch := &pgx.Batch{} + for _, p := range problems { + batch.Queue(`INSERT INTO contest_problems (contest_id, problem_id, charcode) VALUES ($1, $2, $3)`, + contestID, p.ProblemID, p.Charcode) + } - br := tx.SendBatch(ctx, batch) + br := p.conn.SendBatch(ctx, batch) - for i := 0; i < len(problemIDs); i++ { - if _, err := br.Exec(); err != nil { - br.Close() - return 0, fmt.Errorf("insert contest_problem %d failed: %w", i, err) + for i := 0; i < len(problems); i++ { + if _, err := br.Exec(); err != nil { + br.Close() + return 0, fmt.Errorf("insert contest_problem %d (problem_id=%d, charcode=%s) failed: %w", i, problems[i].ProblemID, problems[i].Charcode, err) + } } - } - if err := br.Close(); err != nil { - return 0, fmt.Errorf("batch close failed: %w", err) - } - - if err := tx.Commit(ctx); err != nil { - return 0, fmt.Errorf("commit failed: %w", err) + if err := br.Close(); err != nil { + return 0, fmt.Errorf("batch close failed: %w", err) + } } return contestID, nil } -func (p *Postgres) GetByID(ctx context.Context, contestID int32) (models.Contest, error) { +func (p *Postgres) GetByID(ctx context.Context, contestID int) (models.Contest, error) { var contest models.Contest - query := `SELECT contests.*, users.username AS creator_username, COUNT(entries.id) AS participants - FROM contests - JOIN users ON users.id = contests.creator_id - LEFT JOIN entries ON entries.contest_id = contests.id - WHERE contests.id = $1 - GROUP BY contests.id, users.username` - err := p.pool.QueryRow(ctx, query, contestID).Scan(&contest.ID, &contest.CreatorID, &contest.Title, &contest.Description, &contest.StartTime, &contest.EndTime, &contest.DurationMins, &contest.MaxEntries, &contest.AllowLateJoin, &contest.CreatedAt, &contest.CreatorUsername, &contest.Participants) + query := ` +SELECT + id, creator_id, creator_username, creator_address, title, description, award_type, entry_price_ton_nanos, start_time, end_time, duration_mins, + max_entries, allow_late_join, wallet_id, distribution_payment_id, participants_count, created_at +FROM contests_view +WHERE id = $1` + err := p.conn.QueryRow(ctx, query, contestID).Scan( + &contest.ID, &contest.CreatorID, &contest.CreatorUsername, &contest.CreatorAddress, &contest.Title, &contest.Description, &contest.AwardType, &contest.EntryPriceTonNanos, &contest.StartTime, + &contest.EndTime, &contest.DurationMins, &contest.MaxEntries, &contest.AllowLateJoin, + &contest.WalletID, &contest.DistributionPaymentID, &contest.ParticipantsCount, &contest.CreatedAt) return contest, err } -func (p *Postgres) GetProblemset(ctx context.Context, contestID int32) ([]models.Problem, error) { - query := `SELECT cp.charcode, p.*, u.username AS writer_username - FROM problems p - JOIN contest_problems cp ON p.id = cp.problem_id - JOIN users u ON u.id = p.writer_id - WHERE cp.contest_id = $1 ORDER BY charcode ASC` +func (p *Postgres) GetWallet(ctx context.Context, walletID int) (models.Wallet, error) { + var wallet models.Wallet + query := ` +SELECT + w.id, w.address, w.mnemonic_encrypted, w.created_at +FROM wallets w +WHERE w.id = $1` + err := p.conn.QueryRow(ctx, query, walletID).Scan(&wallet.ID, &wallet.Address, &wallet.MnemonicEncrypted, &wallet.CreatedAt) + return wallet, err +} - rows, err := p.pool.Query(ctx, query, contestID) +func (p *Postgres) GetProblemset(ctx context.Context, contestID int) ([]models.Problem, error) { + query := ` +SELECT + problem_id, charcode, writer_id, writer_username, writer_address, title, statement, + difficulty, time_limit_ms, memory_limit_mb, checker, created_at +FROM contest_problems_view +WHERE contest_id = $1 ORDER BY charcode ASC` + + rows, err := p.conn.Query(ctx, query, contestID) if err != nil { return nil, err } @@ -108,7 +101,7 @@ func (p *Postgres) GetProblemset(ctx context.Context, contestID int32) ([]models var problems []models.Problem for rows.Next() { var problem models.Problem - if err := rows.Scan(&problem.Charcode, &problem.ID, &problem.WriterID, &problem.Title, &problem.Statement, &problem.Difficulty, &problem.TimeLimitMS, &problem.MemoryLimitMB, &problem.Checker, &problem.CreatedAt, &problem.WriterUsername); err != nil { + if err := rows.Scan(&problem.ID, &problem.Charcode, &problem.WriterID, &problem.WriterUsername, &problem.WriterAddress, &problem.Title, &problem.Statement, &problem.Difficulty, &problem.TimeLimitMS, &problem.MemoryLimitMB, &problem.Checker, &problem.CreatedAt); err != nil { return nil, err } problems = append(problems, problem) @@ -116,26 +109,67 @@ func (p *Postgres) GetProblemset(ctx context.Context, contestID int32) ([]models return problems, nil } -func (p *Postgres) ListAll(ctx context.Context, limit int, offset int) (contests []models.Contest, total int, err error) { +func (p *Postgres) ListAll(ctx context.Context, limit int, offset int, filters models.ContestFilters) (contests []models.Contest, total int, err error) { if limit < 0 { limit = defaultLimit } batch := &pgx.Batch{} - batch.Queue(` - SELECT contests.*, users.username AS creator_username, COUNT(entries.id) AS participants - FROM contests - JOIN users ON users.id = contests.creator_id - LEFT JOIN entries ON entries.contest_id = contests.id - WHERE contests.end_time >= now() - GROUP BY contests.id, users.username - ORDER BY contests.id ASC - LIMIT $1 OFFSET $2 - `, limit, offset) - batch.Queue(`SELECT COUNT(*) FROM contests WHERE contests.end_time >= now()`) + whereClauses := []string{"end_time >= now()"} + queryArgs := []interface{}{limit, offset} + countArgs := []interface{}{} + paramIndex := 3 - br := p.pool.SendBatch(ctx, batch) + if filters.CreatorID != 0 { + whereClauses = append(whereClauses, fmt.Sprintf("creator_id = $%d", paramIndex)) + queryArgs = append(queryArgs, filters.CreatorID) + countArgs = append(countArgs, filters.CreatorID) + paramIndex++ + } + + if filters.Title != "" { + whereClauses = append(whereClauses, fmt.Sprintf("LOWER(title) LIKE LOWER($%d)", paramIndex)) + queryArgs = append(queryArgs, "%"+filters.Title+"%") + countArgs = append(countArgs, "%"+filters.Title+"%") + paramIndex++ + } + + whereClause := "WHERE " + strings.Join(whereClauses, " AND ") + + query := fmt.Sprintf(` +SELECT + id, creator_id, creator_username, creator_address, title, description, award_type, entry_price_ton_nanos, start_time, end_time, duration_mins, max_entries, + allow_late_join, wallet_id, distribution_payment_id, participants_count, created_at +FROM contests_view +%s +ORDER BY id ASC +LIMIT $1 OFFSET $2 + `, whereClause) + + batch.Queue(query, queryArgs...) + + countWhereClauses := []string{"end_time >= now()"} + if filters.CreatorID != 0 { + countWhereClauses = append(countWhereClauses, "creator_id = $1") + } + if filters.Title != "" { + countParamIndex := 1 + if filters.CreatorID != 0 { + countParamIndex = 2 + } + countWhereClauses = append(countWhereClauses, fmt.Sprintf("LOWER(title) LIKE LOWER($%d)", countParamIndex)) + } + countWhereClause := strings.Join(countWhereClauses, " AND ") + countQuery := fmt.Sprintf("SELECT COUNT(*) FROM contests_view WHERE %s", countWhereClause) + + if len(countArgs) > 0 { + batch.Queue(countQuery, countArgs...) + } else { + batch.Queue(countQuery) + } + + br := p.conn.SendBatch(ctx, batch) rows, err := br.Query() if err != nil { @@ -147,10 +181,9 @@ func (p *Postgres) ListAll(ctx context.Context, limit int, offset int) (contests for rows.Next() { var c models.Contest if err := rows.Scan( - &c.ID, &c.CreatorID, &c.Title, &c.Description, + &c.ID, &c.CreatorID, &c.CreatorUsername, &c.CreatorAddress, &c.Title, &c.Description, &c.AwardType, &c.EntryPriceTonNanos, &c.StartTime, &c.EndTime, &c.DurationMins, - &c.MaxEntries, &c.AllowLateJoin, &c.CreatedAt, - &c.CreatorUsername, &c.Participants, + &c.MaxEntries, &c.AllowLateJoin, &c.WalletID, &c.DistributionPaymentID, &c.ParticipantsCount, &c.CreatedAt, ); err != nil { rows.Close() br.Close() @@ -172,22 +205,21 @@ func (p *Postgres) ListAll(ctx context.Context, limit int, offset int) (contests return contests, total, nil } -func (p *Postgres) GetWithCreatorID(ctx context.Context, creatorID int32, limit, offset int) (contests []models.Contest, total int, err error) { +func (p *Postgres) GetWithCreatorID(ctx context.Context, creatorID int, limit, offset int) (contests []models.Contest, total int, err error) { batch := &pgx.Batch{} batch.Queue(` - SELECT contests.*, users.username AS creator_username, COUNT(entries.id) AS participants - FROM contests - JOIN users ON users.id = contests.creator_id - LEFT JOIN entries ON entries.contest_id = contests.id - WHERE contests.creator_id = $1 - GROUP BY contests.id, users.username - ORDER BY contests.id ASC - LIMIT $2 OFFSET $3 +SELECT + id, creator_id, creator_username, creator_address, title, description, award_type, entry_price_ton_nanos, start_time, end_time, duration_mins, max_entries, + allow_late_join, wallet_id, distribution_payment_id, participants_count, created_at +FROM contests_view +WHERE creator_id = $1 +ORDER BY id ASC +LIMIT $2 OFFSET $3 `, creatorID, limit, offset) - batch.Queue(`SELECT COUNT(*) FROM contests WHERE creator_id = $1`, creatorID) + batch.Queue(`SELECT COUNT(*) FROM contests_view WHERE creator_id = $1`, creatorID) - br := p.pool.SendBatch(ctx, batch) + br := p.conn.SendBatch(ctx, batch) rows, err := br.Query() if err != nil { @@ -199,10 +231,9 @@ func (p *Postgres) GetWithCreatorID(ctx context.Context, creatorID int32, limit, for rows.Next() { var c models.Contest if err := rows.Scan( - &c.ID, &c.CreatorID, &c.Title, &c.Description, + &c.ID, &c.CreatorID, &c.CreatorUsername, &c.CreatorAddress, &c.Title, &c.Description, &c.AwardType, &c.EntryPriceTonNanos, &c.StartTime, &c.EndTime, &c.DurationMins, - &c.MaxEntries, &c.AllowLateJoin, &c.CreatedAt, - &c.CreatorUsername, &c.Participants, + &c.MaxEntries, &c.AllowLateJoin, &c.WalletID, &c.DistributionPaymentID, &c.ParticipantsCount, &c.CreatedAt, ); err != nil { rows.Close() br.Close() @@ -224,79 +255,99 @@ func (p *Postgres) GetWithCreatorID(ctx context.Context, creatorID int32, limit, return contests, total, nil } -func (p *Postgres) GetEntriesCount(ctx context.Context, contestID int32) (int32, error) { - var count int32 - err := p.pool.QueryRow(ctx, `SELECT COUNT(*) FROM entries WHERE contest_id = $1`, contestID).Scan(&count) +func (p *Postgres) GetEntriesCount(ctx context.Context, contestID int) (int, error) { + var count int + err := p.conn.QueryRow(ctx, `SELECT COUNT(*) FROM entries WHERE contest_id = $1`, contestID).Scan(&count) return count, err } func (p *Postgres) IsTitleOccupied(ctx context.Context, title string) (bool, error) { var count int - err := p.pool.QueryRow(ctx, `SELECT COUNT(*) FROM contests WHERE LOWER(title) = $1`, strings.ToLower(title)).Scan(&count) + err := p.conn.QueryRow(ctx, `SELECT COUNT(*) FROM contests WHERE LOWER(title) = $1`, strings.ToLower(title)).Scan(&count) return count > 0, err } -func (p *Postgres) GetLeaderboard(ctx context.Context, contestID, limit, offset int) (leaderboard []models.LeaderboardEntry, total int, err error) { - batch := &pgx.Batch{} - batch.Queue(` - SELECT u.id AS user_id, u.username, COALESCE(SUM( - CASE - WHEN p.difficulty = 'easy' THEN 1 - WHEN p.difficulty = 'mid' THEN 3 - WHEN p.difficulty = 'hard' THEN 5 - ELSE 0 - END - ), 0) AS points - FROM users u - JOIN entries e ON u.id = e.user_id - JOIN contests c ON e.contest_id = c.id - LEFT JOIN ( - SELECT DISTINCT entry_id, problem_id - FROM submissions - WHERE verdict = 'ok' - ) s ON e.id = s.entry_id - LEFT JOIN problems p ON s.problem_id = p.id - WHERE c.id = $1 - GROUP BY u.id, u.username - ORDER BY points DESC - LIMIT $2 OFFSET $3 - `, contestID, limit, offset) +func (p *Postgres) GetWinnerID(ctx context.Context, contestID int) (int, error) { + // TODO: tie-breaking rules + query := ` SELECT user_id FROM scores + WHERE contest_id = $1 AND points > 0 ORDER BY points DESC LIMIT 1` - batch.Queue(` - SELECT COUNT(DISTINCT u.id) - FROM users u - JOIN entries e ON u.id = e.user_id - WHERE e.contest_id = $1 - `, contestID) + var userID int + err := p.conn.QueryRow(ctx, query, contestID).Scan(&userID) + return userID, err +} - br := p.pool.SendBatch(ctx, batch) +func (p *Postgres) GetScores(ctx context.Context, contestID, limit, offset int) (scores []models.ScoresEntry, total int, err error) { + query := ` + SELECT user_id, username, points, COUNT(*) OVER() AS total + FROM scores + WHERE contest_id = $1 + ORDER BY points DESC + LIMIT $2 OFFSET $3 + ` - rows, err := br.Query() + rows, err := p.conn.Query(ctx, query, contestID, limit, offset) if err != nil { - br.Close() - return nil, 0, fmt.Errorf("leaderboard query failed: %w", err) + return nil, 0, fmt.Errorf("scores query failed: %w", err) } + defer rows.Close() - leaderboard = make([]models.LeaderboardEntry, 0) + scores = make([]models.ScoresEntry, 0) for rows.Next() { - var entry models.LeaderboardEntry - if err := rows.Scan(&entry.UserID, &entry.Username, &entry.Points); err != nil { - rows.Close() - br.Close() + var entry models.ScoresEntry + if err := rows.Scan(&entry.UserID, &entry.Username, &entry.Points, &total); err != nil { return nil, 0, err } - leaderboard = append(leaderboard, entry) + scores = append(scores, entry) } - rows.Close() - if err := br.QueryRow().Scan(&total); err != nil { - br.Close() - return nil, 0, fmt.Errorf("total count query failed: %w", err) + if err := rows.Err(); err != nil { + return nil, 0, err } - if err := br.Close(); err != nil { - return nil, 0, err + return scores, total, nil +} + +func (p *Postgres) SetDistributionPaymentID(ctx context.Context, contestID int, paymentID int) error { + query := `UPDATE contests SET distribution_payment_id = $1 WHERE id = $2` + _, err := p.conn.Exec(ctx, query, paymentID, contestID) + if err != nil { + return fmt.Errorf("set distribution_payment_id failed: %w", err) + } + return nil +} + +func (p *Postgres) GetWithUndistributedAwards(ctx context.Context) ([]models.Contest, error) { + query := ` +SELECT + id, creator_id, creator_username, creator_address, title, description, award_type, entry_price_ton_nanos, start_time, end_time, duration_mins, + max_entries, allow_late_join, wallet_id, distribution_payment_id, participants_count, created_at +FROM contests_view +WHERE distribution_payment_id IS NULL AND end_time < now() AND award_type <> 'no' +ORDER BY end_time ASC` + + rows, err := p.conn.Query(ctx, query) + if err != nil { + return nil, fmt.Errorf("query contests with undistributed awards failed: %w", err) + } + defer rows.Close() + + contests := make([]models.Contest, 0) + for rows.Next() { + var c models.Contest + if err := rows.Scan( + &c.ID, &c.CreatorID, &c.CreatorUsername, &c.CreatorAddress, &c.Title, &c.Description, &c.AwardType, &c.EntryPriceTonNanos, + &c.StartTime, &c.EndTime, &c.DurationMins, + &c.MaxEntries, &c.AllowLateJoin, &c.WalletID, &c.DistributionPaymentID, &c.ParticipantsCount, &c.CreatedAt, + ); err != nil { + return nil, fmt.Errorf("scan failed: %w", err) + } + contests = append(contests, c) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("rows iteration failed: %w", err) } - return leaderboard, total, nil + return contests, nil } diff --git a/internal/storage/repository/postgres/entry/entry.go b/internal/storage/repository/postgres/entry/entry.go index a00bc9f..fbb0e4c 100644 --- a/internal/storage/repository/postgres/entry/entry.go +++ b/internal/storage/repository/postgres/entry/entry.go @@ -3,38 +3,39 @@ package entry import ( "context" - "github.com/jackc/pgx/v5/pgxpool" "github.com/voidcontests/api/internal/storage/models" + "github.com/voidcontests/api/internal/storage/repository/postgres" ) type Postgres struct { - pool *pgxpool.Pool + conn postgres.Transactor } -func New(pool *pgxpool.Pool) *Postgres { - return &Postgres{pool} +func New(conn postgres.Transactor) *Postgres { + return &Postgres{conn} } -func (p *Postgres) Create(ctx context.Context, contestID int32, userID int32) (int, error) { +func (p *Postgres) Create(ctx context.Context, contestID int, userID int) (int, error) { query := `INSERT INTO entries (contest_id, user_id) VALUES ($1, $2) RETURNING id` var id int - err := p.pool.QueryRow(ctx, query, contestID, userID).Scan(&id) + err := p.conn.QueryRow(ctx, query, contestID, userID).Scan(&id) if err != nil { return 0, err } return id, nil } -func (p *Postgres) Get(ctx context.Context, contestID int32, userID int32) (models.Entry, error) { - query := `SELECT id, contest_id, user_id, created_at FROM entries +func (p *Postgres) Get(ctx context.Context, contestID int, userID int) (models.Entry, error) { + query := `SELECT id, contest_id, user_id, payment_id, created_at FROM entries WHERE contest_id = $1 AND user_id = $2` var entry models.Entry - err := p.pool.QueryRow(ctx, query, contestID, userID).Scan( + err := p.conn.QueryRow(ctx, query, contestID, userID).Scan( &entry.ID, &entry.ContestID, &entry.UserID, + &entry.PaymentID, &entry.CreatedAt, ) if err != nil { @@ -42,3 +43,9 @@ func (p *Postgres) Get(ctx context.Context, contestID int32, userID int32) (mode } return entry, nil } + +func (p *Postgres) SetPaymentID(ctx context.Context, entryID int, paymentID int) error { + query := `UPDATE entries SET payment_id = $1 WHERE id = $2` + _, err := p.conn.Exec(ctx, query, paymentID, entryID) + return err +} diff --git a/internal/storage/repository/postgres/payment/payment.go b/internal/storage/repository/postgres/payment/payment.go new file mode 100644 index 0000000..4a56f3a --- /dev/null +++ b/internal/storage/repository/postgres/payment/payment.go @@ -0,0 +1,69 @@ +package payment + +import ( + "context" + "fmt" + + "github.com/voidcontests/api/internal/storage/models" + "github.com/voidcontests/api/internal/storage/repository/postgres" +) + +type Postgres struct { + conn postgres.Transactor +} + +func New(conn postgres.Transactor) *Postgres { + return &Postgres{conn: conn} +} + +func (p *Postgres) Create(ctx context.Context, txHash, fromAddress, toAddress string, amountTonNanos uint64, isIncoming bool) (int, error) { + var paymentID int + query := `INSERT INTO payments (tx_hash, from_address, to_address, amount_ton_nanos, is_incoming) + VALUES ($1, $2, $3, $4, $5) RETURNING id` + err := p.conn.QueryRow(ctx, query, txHash, fromAddress, toAddress, amountTonNanos, isIncoming).Scan(&paymentID) + if err != nil { + return 0, fmt.Errorf("insert payment failed: %w", err) + } + + return paymentID, nil +} + +func (p *Postgres) GetByID(ctx context.Context, paymentID int) (models.Payment, error) { + var payment models.Payment + query := `SELECT id, tx_hash, from_address, to_address, amount_ton_nanos, is_incoming, created_at + FROM payments WHERE id = $1` + err := p.conn.QueryRow(ctx, query, paymentID).Scan( + &payment.ID, + &payment.TxHash, + &payment.FromAddress, + &payment.ToAddress, + &payment.AmountTonNanos, + &payment.IsIncoming, + &payment.CreatedAt, + ) + if err != nil { + return models.Payment{}, fmt.Errorf("get payment by id failed: %w", err) + } + + return payment, nil +} + +func (p *Postgres) GetByTxHash(ctx context.Context, txHash string) (models.Payment, error) { + var payment models.Payment + query := `SELECT id, tx_hash, from_address, to_address, amount_ton_nanos, is_incoming, created_at + FROM payments WHERE tx_hash = $1` + err := p.conn.QueryRow(ctx, query, txHash).Scan( + &payment.ID, + &payment.TxHash, + &payment.FromAddress, + &payment.ToAddress, + &payment.AmountTonNanos, + &payment.IsIncoming, + &payment.CreatedAt, + ) + if err != nil { + return models.Payment{}, fmt.Errorf("get payment by tx_hash failed: %w", err) + } + + return payment, nil +} diff --git a/internal/storage/repository/postgres/problem/problem.go b/internal/storage/repository/postgres/problem/problem.go index c0a886c..57da032 100644 --- a/internal/storage/repository/postgres/problem/problem.go +++ b/internal/storage/repository/postgres/problem/problem.go @@ -5,27 +5,21 @@ import ( "fmt" "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgxpool" "github.com/voidcontests/api/internal/storage/models" + "github.com/voidcontests/api/internal/storage/repository/postgres" ) type Postgres struct { - pool *pgxpool.Pool + conn postgres.Transactor } -func New(pool *pgxpool.Pool) *Postgres { - return &Postgres{pool} +func New(conn postgres.Transactor) *Postgres { + return &Postgres{conn: conn} } -func (p *Postgres) CreateWithTCs(ctx context.Context, writerID int32, title, statement, difficulty string, timeLimitMS, memoryLimitMB int, checker string, tcs []models.TestCaseDTO) (int32, error) { - tx, err := p.pool.BeginTx(ctx, pgx.TxOptions{}) - if err != nil { - return 0, fmt.Errorf("tx begin failed: %w", err) - } - defer tx.Rollback(ctx) - - var problemID int32 - err = tx.QueryRow(ctx, ` +func (p *Postgres) Create(ctx context.Context, writerID int, title, statement, difficulty string, timeLimitMS, memoryLimitMB int, checker string) (int, error) { + var problemID int + err := p.conn.QueryRow(ctx, ` INSERT INTO problems (writer_id, title, statement, difficulty, time_limit_ms, memory_limit_mb, checker) VALUES ($1, $2, $3, $4, $5, $6, $7) RETURNING id @@ -34,80 +28,79 @@ func (p *Postgres) CreateWithTCs(ctx context.Context, writerID int32, title, sta return 0, fmt.Errorf("insert problem failed: %w", err) } - if len(tcs) > 0 { - batch := &pgx.Batch{} - for i, tc := range tcs { - batch.Queue(` - INSERT INTO test_cases (problem_id, ordinal, input, output, is_example) - VALUES ($1, $2, $3, $4, $5) - `, problemID, i+1, tc.Input, tc.Output, tc.IsExample) - } + return problemID, nil +} + +func (p *Postgres) AssociateTestCases(ctx context.Context, problemID int, tcs []models.TestCaseDTO) error { + if len(tcs) == 0 { + return nil + } - br := tx.SendBatch(ctx, batch) + batch := &pgx.Batch{} + for i, tc := range tcs { + batch.Queue(` + INSERT INTO test_cases (problem_id, ordinal, input, output, is_example) + VALUES ($1, $2, $3, $4, $5) + `, problemID, i+1, tc.Input, tc.Output, tc.IsExample) + } - for i := 0; i < batch.Len(); i++ { - if _, err := br.Exec(); err != nil { - br.Close() - return 0, fmt.Errorf("insert test case %d failed: %w", i, err) - } - } + br := p.conn.SendBatch(ctx, batch) + defer br.Close() - if err := br.Close(); err != nil { - return 0, fmt.Errorf("batch close failed: %w", err) + for i := 0; i < batch.Len(); i++ { + if _, err := br.Exec(); err != nil { + return fmt.Errorf("insert test case %d (ordinal=%d, is_example=%v) failed: %w", i, i+1, tcs[i].IsExample, err) } } - if err := tx.Commit(ctx); err != nil { - return 0, fmt.Errorf("commit failed: %w", err) + if err := br.Close(); err != nil { + return fmt.Errorf("batch close failed: %w", err) } - return problemID, nil + return nil } -func (p *Postgres) Get(ctx context.Context, contestID int32, charcode string) (models.Problem, error) { - query := `SELECT p.*, cp.charcode, u.username AS writer_username - FROM problems p - JOIN contest_problems cp ON p.id = cp.problem_id - JOIN users u ON u.id = p.writer_id - WHERE cp.contest_id = $1 AND cp.charcode = $2` +func (p *Postgres) Get(ctx context.Context, contestID int, charcode string) (models.Problem, error) { + query := ` +SELECT + problem_id, charcode, writer_id, writer_username, writer_address, title, statement, + difficulty, time_limit_ms, memory_limit_mb, checker, created_at +FROM contest_problems_view +WHERE contest_id = $1 AND charcode = $2` - row := p.pool.QueryRow(ctx, query, contestID, charcode) + row := p.conn.QueryRow(ctx, query, contestID, charcode) var problem models.Problem err := row.Scan( - &problem.ID, &problem.WriterID, &problem.Title, &problem.Statement, + &problem.ID, &problem.Charcode, &problem.WriterID, &problem.WriterUsername, &problem.WriterAddress, &problem.Title, &problem.Statement, &problem.Difficulty, &problem.TimeLimitMS, &problem.MemoryLimitMB, &problem.Checker, &problem.CreatedAt, - &problem.Charcode, &problem.WriterUsername, ) return problem, err } -func (p *Postgres) GetByID(ctx context.Context, problemID int32) (models.Problem, error) { +func (p *Postgres) GetByID(ctx context.Context, problemID int) (models.Problem, error) { query := `SELECT - p.id, p.writer_id, p.title, p.statement, - p.difficulty, p.time_limit_ms, p.memory_limit_mb, p.checker, p.created_at, - u.username AS writer_username - FROM problems p - JOIN users u ON u.id = p.writer_id - WHERE p.id = $1` + id, writer_id, writer_username, writer_address, title, statement, + difficulty, time_limit_ms, memory_limit_mb, checker, created_at + FROM problems_view + WHERE id = $1` - row := p.pool.QueryRow(ctx, query, problemID) + row := p.conn.QueryRow(ctx, query, problemID) var problem models.Problem err := row.Scan( - &problem.ID, &problem.WriterID, &problem.Title, &problem.Statement, + &problem.ID, &problem.WriterID, &problem.WriterUsername, &problem.WriterAddress, &problem.Title, &problem.Statement, &problem.Difficulty, &problem.TimeLimitMS, &problem.MemoryLimitMB, &problem.Checker, &problem.CreatedAt, - &problem.WriterUsername, ) return problem, err } -func (p *Postgres) GetExampleCases(ctx context.Context, problemID int32) ([]models.TestCase, error) { +func (p *Postgres) GetExampleCases(ctx context.Context, problemID int) ([]models.TestCase, error) { query := `SELECT id, problem_id, ordinal, input, output, is_example FROM test_cases WHERE problem_id = $1 AND is_example = true` - rows, err := p.pool.Query(ctx, query, problemID) + rows, err := p.conn.Query(ctx, query, problemID) if err != nil { return nil, err } @@ -125,11 +118,11 @@ func (p *Postgres) GetExampleCases(ctx context.Context, problemID int32) ([]mode return tcs, rows.Err() } -func (p *Postgres) GetTestCaseByID(ctx context.Context, testCaseID int32) (models.TestCase, error) { +func (p *Postgres) GetTestCaseByID(ctx context.Context, testCaseID int) (models.TestCase, error) { query := `SELECT id, problem_id, ordinal, input, output, is_example FROM test_cases WHERE id = $1` var tc models.TestCase - err := p.pool.QueryRow(ctx, query, testCaseID).Scan( + err := p.conn.QueryRow(ctx, query, testCaseID).Scan( &tc.ID, &tc.ProblemID, &tc.Ordinal, @@ -146,9 +139,13 @@ func (p *Postgres) GetTestCaseByID(ctx context.Context, testCaseID int32) (model } func (p *Postgres) GetAll(ctx context.Context) ([]models.Problem, error) { - query := `SELECT problems.*, users.username AS writer_username FROM problems JOIN users ON users.id = problems.writer_id` + query := ` +SELECT + id, writer_id, writer_username, writer_address, title, statement, difficulty, time_limit_ms, + memory_limit_mb, checker, created_at +FROM problems_view` - rows, err := p.pool.Query(ctx, query) + rows, err := p.conn.Query(ctx, query) if err != nil { return nil, err } @@ -158,8 +155,8 @@ func (p *Postgres) GetAll(ctx context.Context) ([]models.Problem, error) { for rows.Next() { var p models.Problem if err := rows.Scan( - &p.ID, &p.WriterID, &p.Title, &p.Statement, &p.Difficulty, - &p.TimeLimitMS, &p.MemoryLimitMB, &p.Checker, &p.CreatedAt, &p.WriterUsername, + &p.ID, &p.WriterID, &p.WriterUsername, &p.WriterAddress, &p.Title, &p.Statement, &p.Difficulty, + &p.TimeLimitMS, &p.MemoryLimitMB, &p.Checker, &p.CreatedAt, ); err != nil { return nil, err } @@ -169,23 +166,22 @@ func (p *Postgres) GetAll(ctx context.Context) ([]models.Problem, error) { return problems, rows.Err() } -func (p *Postgres) GetWithWriterID(ctx context.Context, writerID int32, limit, offset int) (problems []models.Problem, total int, err error) { +func (p *Postgres) GetWithWriterID(ctx context.Context, writerID int, limit, offset int) (problems []models.Problem, total int, err error) { batch := &pgx.Batch{} batch.Queue(` - SELECT problems.*, users.username AS writer_username - FROM problems - JOIN users ON users.id = problems.writer_id - WHERE writer_id = $1 - ORDER BY problems.id ASC - LIMIT $2 OFFSET $3 +SELECT + id, writer_id, writer_username, writer_address, title, statement, difficulty, time_limit_ms, + memory_limit_mb, checker, created_at +FROM problems_view +WHERE writer_id = $1 +ORDER BY id ASC +LIMIT $2 OFFSET $3 `, writerID, limit, offset) - batch.Queue(` - SELECT COUNT(*) FROM problems WHERE writer_id = $1 - `, writerID) + batch.Queue(`SELECT COUNT(*) FROM problems_view WHERE writer_id = $1`, writerID) - br := p.pool.SendBatch(ctx, batch) + br := p.conn.SendBatch(ctx, batch) rows, err := br.Query() if err != nil { @@ -197,8 +193,8 @@ func (p *Postgres) GetWithWriterID(ctx context.Context, writerID int32, limit, o for rows.Next() { var p models.Problem if err := rows.Scan( - &p.ID, &p.WriterID, &p.Title, &p.Statement, &p.Difficulty, - &p.TimeLimitMS, &p.MemoryLimitMB, &p.Checker, &p.CreatedAt, &p.WriterUsername, + &p.ID, &p.WriterID, &p.WriterUsername, &p.WriterAddress, &p.Title, &p.Statement, &p.Difficulty, + &p.TimeLimitMS, &p.MemoryLimitMB, &p.Checker, &p.CreatedAt, ); err != nil { rows.Close() br.Close() diff --git a/internal/storage/repository/postgres/submission/submission.go b/internal/storage/repository/postgres/submission/submission.go index 6a9f9df..51ac9ee 100644 --- a/internal/storage/repository/postgres/submission/submission.go +++ b/internal/storage/repository/postgres/submission/submission.go @@ -3,10 +3,12 @@ package submission import ( "context" "database/sql" + "errors" "fmt" - "github.com/jackc/pgx/v5/pgxpool" + "github.com/jackc/pgx/v5" "github.com/voidcontests/api/internal/storage/models" + "github.com/voidcontests/api/internal/storage/repository/postgres" ) const ( @@ -14,23 +16,35 @@ const ( ) type Postgres struct { - pool *pgxpool.Pool + conn postgres.Transactor } -func New(pool *pgxpool.Pool) *Postgres { - return &Postgres{pool} +func New(conn postgres.Transactor) *Postgres { + return &Postgres{conn} } -func (p *Postgres) Create(ctx context.Context, entryID int32, problemID int32, code string, language string) (models.Submission, error) { - query := `INSERT INTO submissions (entry_id, problem_id, code, language) +func (p *Postgres) Create(ctx context.Context, entryID int, problemID int, code string, language string) (models.Submission, error) { + var submissionID int + insertQuery := `INSERT INTO submissions (entry_id, problem_id, code, language) VALUES ($1, $2, $3, $4) - RETURNING id, entry_id, problem_id, status, verdict, code, language, created_at` + RETURNING id` + + err := p.conn.QueryRow(ctx, insertQuery, entryID, problemID, code, language).Scan(&submissionID) + if err != nil { + return models.Submission{}, fmt.Errorf("insert failed: %w", err) + } + + selectQuery := `SELECT id, entry_id, contest_id, problem_id, user_id, username, status, verdict, code, language, created_at + FROM submissions_view WHERE id = $1` var submission models.Submission - err := p.pool.QueryRow(ctx, query, entryID, problemID, code, language).Scan( + err = p.conn.QueryRow(ctx, selectQuery, submissionID).Scan( &submission.ID, &submission.EntryID, + &submission.ContestID, &submission.ProblemID, + &submission.UserID, + &submission.Username, &submission.Status, &submission.Verdict, &submission.Code, @@ -41,21 +55,19 @@ func (p *Postgres) Create(ctx context.Context, entryID int32, problemID int32, c return submission, err } -func (p *Postgres) GetProblemStatus(ctx context.Context, entryID int32, problemID int32) (string, error) { +func (p *Postgres) GetProblemStatus(ctx context.Context, entryID int, problemID int) (string, error) { query := ` - SELECT - CASE - WHEN COUNT(*) FILTER (WHERE s.verdict = 'ok') > 0 THEN 'accepted' - WHEN COUNT(*) > 0 THEN 'tried' - ELSE NULL - END AS status - FROM submissions s - WHERE s.entry_id = $1 AND s.problem_id = $2 + SELECT status + FROM problem_statuses + WHERE entry_id = $1 AND problem_id = $2 ` var status sql.NullString - err := p.pool.QueryRow(ctx, query, entryID, problemID).Scan(&status) + err := p.conn.QueryRow(ctx, query, entryID, problemID).Scan(&status) if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return "", nil + } return "", fmt.Errorf("query failed: %w", err) } @@ -65,30 +77,25 @@ func (p *Postgres) GetProblemStatus(ctx context.Context, entryID int32, problemI return "", nil } -func (p *Postgres) GetProblemStatuses(ctx context.Context, entryID int32) (map[int32]string, error) { +func (p *Postgres) GetProblemStatuses(ctx context.Context, entryID int) (map[int]string, error) { query := ` SELECT - s.problem_id, - CASE - WHEN COUNT(*) FILTER (WHERE s.verdict = 'ok') > 0 THEN 'accepted' - WHEN COUNT(*) > 0 THEN 'tried' - ELSE NULL - END AS status - FROM submissions s - WHERE s.entry_id = $1 - GROUP BY s.problem_id + problem_id, + status + FROM problem_statuses + WHERE entry_id = $1 ` - rows, err := p.pool.Query(ctx, query, entryID) + rows, err := p.conn.Query(ctx, query, entryID) if err != nil { return nil, fmt.Errorf("query failed: %w", err) } defer rows.Close() - statuses := make(map[int32]string) + statuses := make(map[int]string) for rows.Next() { - var problemID int32 + var problemID int var status sql.NullString if err := rows.Scan(&problemID, &status); err != nil { @@ -109,15 +116,18 @@ func (p *Postgres) GetProblemStatuses(ctx context.Context, entryID int32) (map[i return statuses, nil } -func (p *Postgres) GetByID(ctx context.Context, submissionID int32) (models.Submission, error) { - query := `SELECT s.id, s.entry_id, s.problem_id, s.status, s.verdict, s.code, s.language, s.created_at - FROM submissions s WHERE s.id = $1` +func (p *Postgres) GetByID(ctx context.Context, submissionID int) (models.Submission, error) { + query := `SELECT id, entry_id, contest_id, problem_id, user_id, username, status, verdict, code, language, created_at + FROM submissions_view WHERE id = $1` var s models.Submission - err := p.pool.QueryRow(ctx, query, submissionID).Scan( + err := p.conn.QueryRow(ctx, query, submissionID).Scan( &s.ID, &s.EntryID, + &s.ContestID, &s.ProblemID, + &s.UserID, + &s.Username, &s.Status, &s.Verdict, &s.Code, @@ -128,22 +138,20 @@ func (p *Postgres) GetByID(ctx context.Context, submissionID int32) (models.Subm return s, err } -func (p *Postgres) ListByProblem(ctx context.Context, entryID int32, charcode string, limit int, offset int) (items []models.Submission, total int, err error) { +func (p *Postgres) ListByProblem(ctx context.Context, entryID int, charcode string, limit int, offset int) (items []models.Submission, total int, err error) { if limit < 0 { limit = defaultLimit } query := ` - SELECT s.id, s.entry_id, s.problem_id, s.status, s.verdict, s.code, s.language, s.created_at, COUNT(*) OVER() as total_count - FROM submissions s - JOIN problems p ON p.id = s.problem_id - JOIN entries e ON s.entry_id = e.id - JOIN contest_problems cp ON cp.contest_id = e.contest_id AND cp.problem_id = s.problem_id + SELECT s.id, s.entry_id, s.contest_id, s.problem_id, s.user_id, s.username, s.status, s.verdict, s.code, s.language, s.created_at, COUNT(*) OVER() as total_count + FROM submissions_view s + JOIN contest_problems cp ON cp.contest_id = s.contest_id AND cp.problem_id = s.problem_id WHERE s.entry_id = $1 AND cp.charcode = $2 ORDER BY s.created_at DESC LIMIT $3 OFFSET $4` - rows, err := p.pool.Query(ctx, query, entryID, charcode, limit, offset) + rows, err := p.conn.Query(ctx, query, entryID, charcode, limit, offset) if err != nil { return nil, 0, fmt.Errorf("query rows failed: %w", err) } @@ -155,7 +163,10 @@ func (p *Postgres) ListByProblem(ctx context.Context, entryID int32, charcode st if err := rows.Scan( &s.ID, &s.EntryID, + &s.ContestID, &s.ProblemID, + &s.UserID, + &s.Username, &s.Status, &s.Verdict, &s.Code, @@ -174,13 +185,13 @@ func (p *Postgres) ListByProblem(ctx context.Context, entryID int32, charcode st return items, total, nil } -func (p *Postgres) GetTestingReport(ctx context.Context, submissionID int32) (models.TestingReport, error) { +func (p *Postgres) GetTestingReport(ctx context.Context, submissionID int) (models.TestingReport, error) { query := `SELECT id, submission_id, passed_tests_count, total_tests_count, first_failed_test_id, first_failed_test_output, stderr, created_at FROM testing_reports WHERE submission_id = $1` var report models.TestingReport - err := p.pool.QueryRow(ctx, query, submissionID).Scan( + err := p.conn.QueryRow(ctx, query, submissionID).Scan( &report.ID, &report.SubmissionID, &report.PassedTestsCount, diff --git a/internal/storage/repository/postgres/txmanager.go b/internal/storage/repository/postgres/txmanager.go new file mode 100644 index 0000000..7543041 --- /dev/null +++ b/internal/storage/repository/postgres/txmanager.go @@ -0,0 +1,49 @@ +package postgres + +import ( + "context" + "fmt" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgxpool" +) + +// TxManager manages database transactions +type TxManager struct { + pool *pgxpool.Pool +} + +// NewTxManager creates a new transaction manager +func NewTxManager(pool *pgxpool.Pool) *TxManager { + return &TxManager{pool: pool} +} + +// Transactor is an interface that can be either a pool or a transaction +type Transactor interface { + Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) + Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) + QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row + SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults +} + +// WithinTransaction executes a function within a database transaction +// If the function returns an error, the transaction is rolled back +// Otherwise, the transaction is committed +func (tm *TxManager) WithinTransaction(ctx context.Context, fn func(ctx context.Context, tx pgx.Tx) error) error { + tx, err := tm.pool.Begin(ctx) + if err != nil { + return fmt.Errorf("begin transaction: %w", err) + } + defer tx.Rollback(ctx) + + if err := fn(ctx, tx); err != nil { + return err + } + + if err := tx.Commit(ctx); err != nil { + return fmt.Errorf("commit transaction: %w", err) + } + + return nil +} diff --git a/internal/storage/repository/postgres/user/user.go b/internal/storage/repository/postgres/user/user.go index 2ad2375..22843ea 100644 --- a/internal/storage/repository/postgres/user/user.go +++ b/internal/storage/repository/postgres/user/user.go @@ -3,27 +3,28 @@ package user import ( "context" - "github.com/jackc/pgx/v5/pgxpool" "github.com/voidcontests/api/internal/storage/models" + "github.com/voidcontests/api/internal/storage/repository/postgres" ) type Postgres struct { - pool *pgxpool.Pool + conn postgres.Transactor } -func New(pool *pgxpool.Pool) *Postgres { - return &Postgres{pool} +func New(conn postgres.Transactor) *Postgres { + return &Postgres{conn} } func (p *Postgres) GetByCredentials(ctx context.Context, username string, passwordHash string) (models.User, error) { var user models.User - query := `SELECT id, username, password_hash, role_id, created_at FROM users WHERE username = $1 AND password_hash = $2` - err := p.pool.QueryRow(ctx, query, username, passwordHash).Scan( + query := `SELECT id, username, password_hash, role_id, address, created_at FROM users WHERE username = $1 AND password_hash = $2` + err := p.conn.QueryRow(ctx, query, username, passwordHash).Scan( &user.ID, &user.Username, &user.PasswordHash, &user.RoleID, + &user.Address, &user.CreatedAt, ) return user, err @@ -35,14 +36,15 @@ func (p *Postgres) Create(ctx context.Context, username string, passwordHash str query := ` INSERT INTO users (username, password_hash, role_id) VALUES ($1, $2, (SELECT id FROM roles WHERE is_default = true LIMIT 1)) - RETURNING id, username, password_hash, role_id, created_at + RETURNING id, username, password_hash, role_id, address, created_at ` - err := p.pool.QueryRow(ctx, query, username, passwordHash).Scan( + err := p.conn.QueryRow(ctx, query, username, passwordHash).Scan( &user.ID, &user.Username, &user.PasswordHash, &user.RoleID, + &user.Address, &user.CreatedAt, ) return user, err @@ -52,7 +54,7 @@ func (p *Postgres) Exists(ctx context.Context, username string) (bool, error) { var count int query := `SELECT COUNT(*) FROM users WHERE username = $1` - err := p.pool.QueryRow(ctx, query, username).Scan(&count) + err := p.conn.QueryRow(ctx, query, username).Scan(&count) if err != nil { return false, err } @@ -60,21 +62,37 @@ func (p *Postgres) Exists(ctx context.Context, username string) (bool, error) { return count > 0, nil } -func (p *Postgres) GetByID(ctx context.Context, id int32) (models.User, error) { +func (p *Postgres) GetByID(ctx context.Context, id int) (models.User, error) { var user models.User - query := `SELECT id, username, password_hash, role_id, created_at FROM users WHERE id = $1` - err := p.pool.QueryRow(ctx, query, id).Scan( + query := `SELECT id, username, password_hash, role_id, address, created_at FROM users WHERE id = $1` + err := p.conn.QueryRow(ctx, query, id).Scan( &user.ID, &user.Username, &user.PasswordHash, &user.RoleID, + &user.Address, &user.CreatedAt, ) return user, err } -func (p *Postgres) GetRole(ctx context.Context, userID int32) (models.Role, error) { +func (p *Postgres) GetByUsername(ctx context.Context, username string) (models.User, error) { + var user models.User + + query := `SELECT id, username, password_hash, role_id, address, created_at FROM users WHERE username = $1` + err := p.conn.QueryRow(ctx, query, username).Scan( + &user.ID, + &user.Username, + &user.PasswordHash, + &user.RoleID, + &user.Address, + &user.CreatedAt, + ) + return user, err +} + +func (p *Postgres) GetRole(ctx context.Context, userID int) (models.Role, error) { var role models.Role query := ` @@ -83,7 +101,7 @@ func (p *Postgres) GetRole(ctx context.Context, userID int32) (models.Role, erro JOIN roles r ON u.role_id = r.id WHERE u.id = $1 ` - err := p.pool.QueryRow(ctx, query, userID).Scan( + err := p.conn.QueryRow(ctx, query, userID).Scan( &role.ID, &role.Name, &role.CreatedProblemsLimit, @@ -94,18 +112,41 @@ func (p *Postgres) GetRole(ctx context.Context, userID int32) (models.Role, erro return role, err } -func (p *Postgres) GetCreatedProblemsCount(ctx context.Context, userID int32) (int, error) { +func (p *Postgres) GetCreatedProblemsCount(ctx context.Context, userID int) (int, error) { var count int query := `SELECT COUNT(*) FROM problems WHERE writer_id = $1` - err := p.pool.QueryRow(ctx, query, userID).Scan(&count) + err := p.conn.QueryRow(ctx, query, userID).Scan(&count) return count, err } -func (p *Postgres) GetCreatedContestsCount(ctx context.Context, userID int32) (int, error) { +func (p *Postgres) GetCreatedContestsCount(ctx context.Context, userID int) (int, error) { var count int query := `SELECT COUNT(*) FROM contests WHERE creator_id = $1` - err := p.pool.QueryRow(ctx, query, userID).Scan(&count) + err := p.conn.QueryRow(ctx, query, userID).Scan(&count) return count, err } + +func (p *Postgres) UpdateUser(ctx context.Context, userID int, params models.UpdateUserParams) (models.User, error) { + var user models.User + + query := ` + UPDATE users + SET + username = COALESCE($2, username), + address = CASE WHEN $3::text IS NOT NULL THEN $3 ELSE address END + WHERE id = $1 + RETURNING id, username, password_hash, role_id, address, created_at + ` + + err := p.conn.QueryRow(ctx, query, userID, params.Username, params.Address).Scan( + &user.ID, + &user.Username, + &user.PasswordHash, + &user.RoleID, + &user.Address, + &user.CreatedAt, + ) + return user, err +} diff --git a/internal/storage/repository/postgres/wallet/wallet.go b/internal/storage/repository/postgres/wallet/wallet.go new file mode 100644 index 0000000..5a0821c --- /dev/null +++ b/internal/storage/repository/postgres/wallet/wallet.go @@ -0,0 +1,27 @@ +package wallet + +import ( + "context" + "fmt" + + "github.com/voidcontests/api/internal/storage/repository/postgres" +) + +type Postgres struct { + conn postgres.Transactor +} + +func New(conn postgres.Transactor) *Postgres { + return &Postgres{conn: conn} +} + +func (p *Postgres) Create(ctx context.Context, address, mnemonic string) (int, error) { + var walletID int + query := `INSERT INTO wallets (address, mnemonic_encrypted) VALUES ($1, $2) RETURNING id` + err := p.conn.QueryRow(ctx, query, address, mnemonic).Scan(&walletID) + if err != nil { + return 0, fmt.Errorf("insert wallet failed: %w", err) + } + + return walletID, nil +} diff --git a/internal/storage/repository/repository.go b/internal/storage/repository/repository.go index c8eda6a..3ec2526 100644 --- a/internal/storage/repository/repository.go +++ b/internal/storage/repository/repository.go @@ -4,13 +4,17 @@ import ( "context" "time" + "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" "github.com/voidcontests/api/internal/storage/models" + "github.com/voidcontests/api/internal/storage/repository/postgres" "github.com/voidcontests/api/internal/storage/repository/postgres/contest" "github.com/voidcontests/api/internal/storage/repository/postgres/entry" + "github.com/voidcontests/api/internal/storage/repository/postgres/payment" "github.com/voidcontests/api/internal/storage/repository/postgres/problem" "github.com/voidcontests/api/internal/storage/repository/postgres/submission" "github.com/voidcontests/api/internal/storage/repository/postgres/user" + "github.com/voidcontests/api/internal/storage/repository/postgres/wallet" ) type Repository struct { @@ -19,6 +23,8 @@ type Repository struct { Problem Problem Entry Entry Submission Submission + Payment Payment + TxManager *postgres.TxManager } func New(pool *pgxpool.Pool) *Repository { @@ -28,6 +34,30 @@ func New(pool *pgxpool.Pool) *Repository { Problem: problem.New(pool), Entry: entry.New(pool), Submission: submission.New(pool), + Payment: payment.New(pool), + TxManager: postgres.NewTxManager(pool), + } +} + +type TxRepository struct { + User User + Contest Contest + Problem Problem + Entry Entry + Submission Submission + Wallet Wallet + Payment Payment +} + +func NewTxRepository(tx pgx.Tx) *TxRepository { + return &TxRepository{ + Contest: contest.New(tx), + Wallet: wallet.New(tx), + User: user.New(tx), + Entry: entry.New(tx), + Submission: submission.New(tx), + Problem: problem.New(tx), + Payment: payment.New(tx), } } @@ -35,44 +65,61 @@ type User interface { GetByCredentials(ctx context.Context, username string, passwordHash string) (models.User, error) Create(ctx context.Context, username string, passwordHash string) (models.User, error) Exists(ctx context.Context, username string) (bool, error) - GetByID(ctx context.Context, id int32) (models.User, error) - GetRole(ctx context.Context, userID int32) (models.Role, error) - GetCreatedProblemsCount(ctx context.Context, userID int32) (int, error) - GetCreatedContestsCount(ctx context.Context, userID int32) (int, error) + GetByID(ctx context.Context, id int) (models.User, error) + GetByUsername(ctx context.Context, username string) (models.User, error) + GetRole(ctx context.Context, userID int) (models.Role, error) + GetCreatedProblemsCount(ctx context.Context, userID int) (int, error) + GetCreatedContestsCount(ctx context.Context, userID int) (int, error) + UpdateUser(ctx context.Context, userID int, params models.UpdateUserParams) (models.User, error) } type Contest interface { - Create(ctx context.Context, creatorID int32, title, description string, startTime, endTime time.Time, durationMins, maxEntries int32, allowLateJoin bool) (int32, error) - CreateWithProblemIDs(ctx context.Context, creatorID int32, title, desc string, startTime, endTime time.Time, durationMins, maxEntries int32, allowLateJoin bool, problemIDs []int32) (int32, error) - GetByID(ctx context.Context, contestID int32) (models.Contest, error) - GetProblemset(ctx context.Context, contestID int32) ([]models.Problem, error) - ListAll(ctx context.Context, limit int, offset int) (contests []models.Contest, total int, err error) - GetWithCreatorID(ctx context.Context, creatorID int32, limit, offset int) (contests []models.Contest, total int, err error) - GetEntriesCount(ctx context.Context, contestID int32) (int32, error) + Create(ctx context.Context, creatorID int, title, desc, awardType string, entryPriceTonNanos uint64, startTime, endTime time.Time, durationMins, maxEntries int, allowLateJoin bool, problems []models.ProblemCharcode, walletID *int) (int, error) + GetByID(ctx context.Context, contestID int) (models.Contest, error) + GetProblemset(ctx context.Context, contestID int) ([]models.Problem, error) + ListAll(ctx context.Context, limit int, offset int, filters models.ContestFilters) (contests []models.Contest, total int, err error) + GetWithCreatorID(ctx context.Context, creatorID int, limit, offset int) (contests []models.Contest, total int, err error) + GetEntriesCount(ctx context.Context, contestID int) (int, error) IsTitleOccupied(ctx context.Context, title string) (bool, error) - GetLeaderboard(ctx context.Context, contestID, limit, offset int) (leaderboard []models.LeaderboardEntry, total int, err error) + GetScores(ctx context.Context, contestID, limit, offset int) (scores []models.ScoresEntry, total int, err error) + GetWallet(ctx context.Context, walletID int) (models.Wallet, error) + SetDistributionPaymentID(ctx context.Context, contestID int, paymentID int) error + GetWithUndistributedAwards(ctx context.Context) ([]models.Contest, error) + GetWinnerID(ctx context.Context, contestID int) (int, error) +} + +type Wallet interface { + Create(ctx context.Context, address, mnemonic string) (int, error) } type Problem interface { - CreateWithTCs(ctx context.Context, writerID int32, title string, statement string, difficulty string, timeLimitMS, memoryLimitMB int, checker string, tcs []models.TestCaseDTO) (int32, error) - Get(ctx context.Context, contestID int32, charcode string) (models.Problem, error) - GetByID(ctx context.Context, problemID int32) (models.Problem, error) - GetExampleCases(ctx context.Context, problemID int32) ([]models.TestCase, error) - GetTestCaseByID(ctx context.Context, testCaseID int32) (models.TestCase, error) + Create(ctx context.Context, writerID int, title, statement, difficulty string, timeLimitMS, memoryLimitMB int, checker string) (int, error) + AssociateTestCases(ctx context.Context, problemID int, tcs []models.TestCaseDTO) error + Get(ctx context.Context, contestID int, charcode string) (models.Problem, error) + GetByID(ctx context.Context, problemID int) (models.Problem, error) + GetExampleCases(ctx context.Context, problemID int) ([]models.TestCase, error) + GetTestCaseByID(ctx context.Context, testCaseID int) (models.TestCase, error) GetAll(ctx context.Context) ([]models.Problem, error) - GetWithWriterID(ctx context.Context, writerID int32, limit, offset int) (problems []models.Problem, total int, err error) + GetWithWriterID(ctx context.Context, writerID int, limit, offset int) (problems []models.Problem, total int, err error) } type Entry interface { - Create(ctx context.Context, contestID int32, userID int32) (int, error) - Get(ctx context.Context, contestID int32, userID int32) (models.Entry, error) + Create(ctx context.Context, contestID int, userID int) (int, error) + Get(ctx context.Context, contestID int, userID int) (models.Entry, error) + SetPaymentID(ctx context.Context, entryID int, paymentID int) error } type Submission interface { - Create(ctx context.Context, entryID int32, problemID int32, code string, language string) (models.Submission, error) - GetProblemStatus(ctx context.Context, entryID int32, problemID int32) (string, error) - GetProblemStatuses(ctx context.Context, entryID int32) (map[int32]string, error) - GetByID(ctx context.Context, submissionID int32) (models.Submission, error) - ListByProblem(ctx context.Context, entryID int32, charcode string, limit int, offset int) (items []models.Submission, total int, err error) - GetTestingReport(ctx context.Context, submissionID int32) (models.TestingReport, error) + Create(ctx context.Context, entryID int, problemID int, code string, language string) (models.Submission, error) + GetProblemStatus(ctx context.Context, entryID int, problemID int) (string, error) + GetProblemStatuses(ctx context.Context, entryID int) (map[int]string, error) + GetByID(ctx context.Context, submissionID int) (models.Submission, error) + ListByProblem(ctx context.Context, entryID int, charcode string, limit int, offset int) (items []models.Submission, total int, err error) + GetTestingReport(ctx context.Context, submissionID int) (models.TestingReport, error) +} + +type Payment interface { + Create(ctx context.Context, txHash, fromAddress, toAddress string, amountTonNanos uint64, isIncoming bool) (int, error) + GetByID(ctx context.Context, paymentID int) (models.Payment, error) + GetByTxHash(ctx context.Context, txHash string) (models.Payment, error) } diff --git a/migrations/000001_tables.down.sql b/migrations/000001_tables.down.sql new file mode 100644 index 0000000..5182742 --- /dev/null +++ b/migrations/000001_tables.down.sql @@ -0,0 +1,30 @@ +DROP INDEX IF EXISTS idx_testing_reports_first_failed_test_id; +DROP INDEX IF EXISTS idx_testing_reports_submission_id; +DROP INDEX IF EXISTS idx_submissions_problem_id; +DROP INDEX IF EXISTS idx_submissions_entry_id; +DROP INDEX IF EXISTS idx_entries_payment_id; +DROP INDEX IF EXISTS idx_entries_user_id; +DROP INDEX IF EXISTS idx_entries_contest_id; +DROP INDEX IF EXISTS idx_test_cases_problem_id; +DROP INDEX IF EXISTS idx_problems_writer_id; +DROP INDEX IF EXISTS unique_contest_wallet_id; +DROP INDEX IF EXISTS idx_contests_wallet_id; +DROP INDEX IF EXISTS idx_contests_distribution_payment_id; +DROP INDEX IF EXISTS idx_contests_creator_id; +DROP INDEX IF EXISTS unique_user_address; +DROP INDEX IF EXISTS idx_users_role_id; + +DROP TABLE IF EXISTS testing_reports; +DROP TABLE IF EXISTS submissions; +DROP TABLE IF EXISTS entries; +DROP TABLE IF EXISTS contest_problems; +DROP TABLE IF EXISTS test_cases; +DROP TABLE IF EXISTS problems; +DROP TABLE IF EXISTS contests; +DROP TABLE IF EXISTS payments; +DROP TABLE IF EXISTS wallets; +DROP TABLE IF EXISTS users; +DROP TABLE IF EXISTS roles; + +DROP TYPE IF EXISTS difficulty; +DROP TYPE IF EXISTS award_type; diff --git a/migrations/000001_tables.up.sql b/migrations/000001_tables.up.sql new file mode 100644 index 0000000..d73c099 --- /dev/null +++ b/migrations/000001_tables.up.sql @@ -0,0 +1,145 @@ +CREATE TABLE roles ( + id SERIAL PRIMARY KEY, + name VARCHAR(20) UNIQUE NOT NULL, + created_problems_limit INTEGER NOT NULL, + created_contests_limit INTEGER NOT NULL, + is_default BOOLEAN DEFAULT false NOT NULL, + created_at TIMESTAMP DEFAULT now() NOT NULL +); + +INSERT INTO roles (name, created_problems_limit, created_contests_limit, is_default) VALUES + ('admin', -1, -1, false), + ('unlimited', -1, -1, false), + ('limited', 10, 2, true), + ('banned', 0, 0, false); + +CREATE TABLE users ( + id SERIAL PRIMARY KEY, + username VARCHAR(50) UNIQUE NOT NULL, + password_hash VARCHAR(255) NOT NULL, + role_id INTEGER NOT NULL REFERENCES roles(id) ON DELETE RESTRICT, + address VARCHAR(48) DEFAULT '' NOT NULL, + created_at TIMESTAMP DEFAULT now() NOT NULL +); + +CREATE INDEX idx_users_role_id ON users(role_id); +CREATE UNIQUE INDEX unique_user_address ON users(address) WHERE address <> ''; + +CREATE TABLE wallets ( + id SERIAL PRIMARY KEY, + address VARCHAR(48) UNIQUE NOT NULL, + mnemonic_encrypted TEXT NOT NULL, + created_at TIMESTAMP DEFAULT now() NOT NULL +); + +CREATE TABLE payments ( + id SERIAL PRIMARY KEY, + tx_hash VARCHAR(64) UNIQUE NOT NULL, + from_address VARCHAR(48) NOT NULL, + to_address VARCHAR(48) NOT NULL, + amount_ton_nanos BIGINT NOT NULL CHECK (amount_ton_nanos >= 0), + is_incoming BOOLEAN NOT NULL, + created_at TIMESTAMP DEFAULT now() NOT NULL +); + +CREATE TYPE award_type AS ENUM ('no', 'pool', 'sponsored'); + +CREATE TABLE contests ( + id SERIAL PRIMARY KEY, + creator_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE, + title VARCHAR(64) NOT NULL, + description VARCHAR(300) DEFAULT '' NOT NULL, + award_type award_type NOT NULL, + entry_price_ton_nanos BIGINT DEFAULT 0 NOT NULL, + distribution_payment_id INTEGER REFERENCES payments(id) ON DELETE SET NULL, + start_time TIMESTAMP NOT NULL, + end_time TIMESTAMP NOT NULL, + duration_mins INTEGER NOT NULL CHECK (duration_mins >= 0), + max_entries INTEGER DEFAULT 0 NOT NULL CHECK (max_entries >= 0), + allow_late_join BOOLEAN DEFAULT true NOT NULL, + wallet_id INTEGER REFERENCES wallets(id) ON DELETE RESTRICT, + created_at TIMESTAMP DEFAULT now() NOT NULL, + CHECK (start_time < end_time) +); + +CREATE INDEX idx_contests_creator_id ON contests(creator_id); +CREATE INDEX idx_contests_distribution_payment_id ON contests(distribution_payment_id); +CREATE INDEX idx_contests_wallet_id ON contests(wallet_id); +CREATE UNIQUE INDEX unique_contest_wallet_id ON contests(wallet_id) WHERE wallet_id IS NOT NULL; + +CREATE TYPE difficulty AS ENUM ('easy', 'mid', 'hard'); + +CREATE TABLE problems ( + id SERIAL PRIMARY KEY, + writer_id INTEGER NOT NULL REFERENCES users(id) ON DELETE RESTRICT, + title VARCHAR(64) NOT NULL, + statement TEXT NOT NULL, + difficulty difficulty NOT NULL, + time_limit_ms INTEGER DEFAULT 2000 NOT NULL CHECK (time_limit_ms >= 0), + memory_limit_mb INTEGER DEFAULT 128 NOT NULL CHECK (memory_limit_mb >= 0), + checker VARCHAR(10) NOT NULL DEFAULT 'tokens', + created_at TIMESTAMP DEFAULT now() NOT NULL +); + +CREATE INDEX idx_problems_writer_id ON problems(writer_id); + +CREATE TABLE test_cases ( + id SERIAL PRIMARY KEY, + problem_id INTEGER NOT NULL REFERENCES problems(id) ON DELETE CASCADE, + ordinal INTEGER NOT NULL, + input TEXT NOT NULL, + output TEXT NOT NULL, + is_example BOOLEAN DEFAULT false NOT NULL, + UNIQUE (problem_id, ordinal) +); + +CREATE INDEX idx_test_cases_problem_id ON test_cases(problem_id); + +CREATE TABLE contest_problems ( + contest_id INTEGER NOT NULL REFERENCES contests(id) ON DELETE CASCADE, + problem_id INTEGER NOT NULL REFERENCES problems(id) ON DELETE CASCADE, + charcode VARCHAR(2) NOT NULL, + PRIMARY KEY (contest_id, problem_id), + UNIQUE (contest_id, charcode) +); + +CREATE TABLE entries ( + id SERIAL PRIMARY KEY, + contest_id INTEGER NOT NULL REFERENCES contests(id) ON DELETE CASCADE, + user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE, + payment_id INTEGER REFERENCES payments(id) ON DELETE SET NULL, + created_at TIMESTAMP DEFAULT now() NOT NULL, + UNIQUE (contest_id, user_id) +); + +CREATE INDEX idx_entries_contest_id ON entries(contest_id); +CREATE INDEX idx_entries_user_id ON entries(user_id); +CREATE INDEX idx_entries_payment_id ON entries(payment_id); + +CREATE TABLE submissions ( + id SERIAL PRIMARY KEY, + entry_id INTEGER NOT NULL REFERENCES entries(id) ON DELETE CASCADE, + problem_id INTEGER NOT NULL REFERENCES problems(id) ON DELETE CASCADE, + status VARCHAR(20) NOT NULL DEFAULT 'pending', + verdict VARCHAR(30) NOT NULL DEFAULT 'not_judged', + code TEXT NOT NULL, + language VARCHAR(20) NOT NULL, + created_at TIMESTAMP DEFAULT now() NOT NULL +); + +CREATE INDEX idx_submissions_entry_id ON submissions(entry_id); +CREATE INDEX idx_submissions_problem_id ON submissions(problem_id); + +CREATE TABLE testing_reports ( + id SERIAL PRIMARY KEY, + submission_id INTEGER NOT NULL REFERENCES submissions(id) ON DELETE CASCADE, + passed_tests_count INTEGER DEFAULT 0 NOT NULL CHECK (passed_tests_count >= 0), + total_tests_count INTEGER DEFAULT 0 NOT NULL CHECK (total_tests_count >= 0), + first_failed_test_id INTEGER REFERENCES test_cases(id), + first_failed_test_output TEXT DEFAULT '', + stderr TEXT DEFAULT '' NOT NULL, + created_at TIMESTAMP DEFAULT now() NOT NULL +); + +CREATE INDEX idx_testing_reports_submission_id ON testing_reports(submission_id); +CREATE INDEX idx_testing_reports_first_failed_test_id ON testing_reports(first_failed_test_id); diff --git a/migrations/000002_views.down.sql b/migrations/000002_views.down.sql new file mode 100644 index 0000000..723b730 --- /dev/null +++ b/migrations/000002_views.down.sql @@ -0,0 +1,6 @@ +DROP VIEW IF EXISTS submissions_view; +DROP VIEW IF EXISTS contest_problems_view; +DROP VIEW IF EXISTS problem_statuses; +DROP VIEW IF EXISTS problems_view; +DROP VIEW IF EXISTS contests_view; +DROP VIEW IF EXISTS scores; diff --git a/migrations/000002_views.up.sql b/migrations/000002_views.up.sql new file mode 100644 index 0000000..bcf9f77 --- /dev/null +++ b/migrations/000002_views.up.sql @@ -0,0 +1,118 @@ +CREATE VIEW scores AS +SELECT + e.contest_id, + u.id AS user_id, + u.username, + COALESCE(SUM( + CASE + WHEN p.difficulty = 'easy' THEN 1 + WHEN p.difficulty = 'mid' THEN 2 + WHEN p.difficulty = 'hard' THEN 3 + ELSE 0 + END + ), 0) AS points +FROM users u +JOIN entries e ON u.id = e.user_id +JOIN contests c ON e.contest_id = c.id +LEFT JOIN ( + SELECT DISTINCT s.entry_id, s.problem_id + FROM submissions s + JOIN entries e2 ON s.entry_id = e2.id + JOIN contests c2 ON e2.contest_id = c2.id + WHERE s.verdict = 'ok' + AND s.created_at <= c2.end_time +) s ON e.id = s.entry_id +LEFT JOIN problems p ON s.problem_id = p.id +GROUP BY e.contest_id, u.id, u.username; + + +CREATE VIEW contests_view AS +SELECT + c.id, + c.creator_id, + u.username AS creator_username, + u.address AS creator_address, + c.title, + c.description, + c.award_type, + c.entry_price_ton_nanos, + c.distribution_payment_id, + c.start_time, + c.end_time, + c.duration_mins, + c.max_entries, + c.allow_late_join, + c.wallet_id, + COUNT(e.id) AS participants_count, + c.created_at +FROM contests c +JOIN users u ON u.id = c.creator_id +LEFT JOIN entries e ON e.contest_id = c.id +GROUP BY c.id, u.username, u.address; + + +CREATE VIEW problems_view AS +SELECT + p.id, + p.writer_id, + u.username AS writer_username, + u.address AS writer_address, + p.title, + p.statement, + p.difficulty, + p.time_limit_ms, + p.memory_limit_mb, + p.checker, + p.created_at +FROM problems p +JOIN users u ON u.id = p.writer_id; + + +CREATE VIEW problem_statuses AS +SELECT + s.problem_id, + s.entry_id, + CASE + WHEN COUNT(*) FILTER (WHERE s.verdict = 'ok') > 0 THEN 'accepted' + WHEN COUNT(*) > 0 THEN 'tried' + END AS status +FROM submissions s +GROUP BY s.entry_id, s.problem_id; + + +CREATE VIEW contest_problems_view AS +SELECT + p.id AS problem_id, + cp.charcode, + cp.contest_id, + p.writer_id, + u.username AS writer_username, + u.address AS writer_address, + p.title, + p.statement, + p.difficulty, + p.time_limit_ms, + p.memory_limit_mb, + p.checker, + p.created_at +FROM problems p +JOIN contest_problems cp ON p.id = cp.problem_id +JOIN users u ON u.id = p.writer_id; + + +CREATE VIEW submissions_view AS +SELECT + s.id, + s.entry_id, + e.contest_id, + s.problem_id, + e.user_id, + u.username, + s.status, + s.verdict, + s.code, + s.language, + s.created_at +FROM submissions s +JOIN entries e ON s.entry_id = e.id +JOIN users u ON e.user_id = u.id; diff --git a/pkg/echoctx/echoctx_test.go b/pkg/echoctx/echoctx_test.go new file mode 100644 index 0000000..64c7d85 --- /dev/null +++ b/pkg/echoctx/echoctx_test.go @@ -0,0 +1,112 @@ +package echoctx + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" +) + +func TestLookup(t *testing.T) { + e := echo.New() + + t.Run("returns value and true when key exists with correct type", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + c.Set("test-string", "hello") + c.Set("test-int", 42) + c.Set("test-bool", true) + + str, ok := Lookup[string](c, "test-string") + assert.True(t, ok) + assert.Equal(t, "hello", str) + + num, ok := Lookup[int](c, "test-int") + assert.True(t, ok) + assert.Equal(t, 42, num) + + boolVal, ok := Lookup[bool](c, "test-bool") + assert.True(t, ok) + assert.True(t, boolVal) + }) + + t.Run("returns zero value and false when key does not exist", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + str, ok := Lookup[string](c, "nonexistent") + assert.False(t, ok) + assert.Equal(t, "", str) + + num, ok := Lookup[int](c, "nonexistent") + assert.False(t, ok) + assert.Equal(t, 0, num) + }) + + t.Run("returns zero value and false when type mismatch", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + c.Set("test-string", "hello") + + // Try to get string value as int + num, ok := Lookup[int](c, "test-string") + assert.False(t, ok) + assert.Equal(t, 0, num) + + // Try to get string value as bool + boolVal, ok := Lookup[bool](c, "test-string") + assert.False(t, ok) + assert.False(t, boolVal) + }) + + t.Run("works with custom types", func(t *testing.T) { + type CustomStruct struct { + Name string + Age int + } + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + expected := CustomStruct{Name: "John", Age: 30} + c.Set("custom", expected) + + result, ok := Lookup[CustomStruct](c, "custom") + assert.True(t, ok) + assert.Equal(t, expected, result) + }) + + t.Run("works with pointer types", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + str := "hello" + c.Set("ptr", &str) + + ptr, ok := Lookup[*string](c, "ptr") + assert.True(t, ok) + assert.NotNil(t, ptr) + assert.Equal(t, "hello", *ptr) + }) + + t.Run("returns false when value is nil", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + c.Set("nil-value", nil) + + str, ok := Lookup[string](c, "nil-value") + assert.False(t, ok) + assert.Equal(t, "", str) + }) +} diff --git a/pkg/ratelimit/ratelimit_test.go b/pkg/ratelimit/ratelimit_test.go new file mode 100644 index 0000000..9977779 --- /dev/null +++ b/pkg/ratelimit/ratelimit_test.go @@ -0,0 +1,254 @@ +package ratelimit + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" +) + +func TestWithTimeout(t *testing.T) { + e := echo.New() + + t.Run("allows first request from IP", func(t *testing.T) { + middleware := WithTimeout(1 * time.Second) + handler := middleware(func(c echo.Context) error { + return c.String(http.StatusOK, "success") + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-Real-IP", "192.168.1.1") + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := handler(c) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) + }) + + t.Run("blocks second request from same IP within timeout", func(t *testing.T) { + middleware := WithTimeout(500 * time.Millisecond) + handler := middleware(func(c echo.Context) error { + return c.String(http.StatusOK, "success") + }) + + req1 := httptest.NewRequest(http.MethodGet, "/", nil) + req1.Header.Set("X-Real-IP", "192.168.1.2") + rec1 := httptest.NewRecorder() + c1 := e.NewContext(req1, rec1) + + err := handler(c1) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec1.Code) + + req2 := httptest.NewRequest(http.MethodGet, "/", nil) + req2.Header.Set("X-Real-IP", "192.168.1.2") + rec2 := httptest.NewRecorder() + c2 := e.NewContext(req2, rec2) + + err = handler(c2) + assert.NoError(t, err) + assert.Equal(t, http.StatusTooManyRequests, rec2.Code) + }) + + t.Run("allows request after timeout expires", func(t *testing.T) { + duration := 100 * time.Millisecond + middleware := WithTimeout(duration) + handler := middleware(func(c echo.Context) error { + return c.String(http.StatusOK, "success") + }) + + req1 := httptest.NewRequest(http.MethodGet, "/", nil) + req1.Header.Set("X-Real-IP", "192.168.1.3") + rec1 := httptest.NewRecorder() + c1 := e.NewContext(req1, rec1) + + err := handler(c1) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec1.Code) + + time.Sleep(duration + 10*time.Millisecond) + + req2 := httptest.NewRequest(http.MethodGet, "/", nil) + req2.Header.Set("X-Real-IP", "192.168.1.3") + rec2 := httptest.NewRecorder() + c2 := e.NewContext(req2, rec2) + + err = handler(c2) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec2.Code) + }) + + t.Run("tracks different IPs independently", func(t *testing.T) { + middleware := WithTimeout(1 * time.Second) + handler := middleware(func(c echo.Context) error { + return c.String(http.StatusOK, "success") + }) + + req1 := httptest.NewRequest(http.MethodGet, "/", nil) + req1.Header.Set("X-Real-IP", "192.168.1.4") + rec1 := httptest.NewRecorder() + c1 := e.NewContext(req1, rec1) + + err := handler(c1) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec1.Code) + + req2 := httptest.NewRequest(http.MethodGet, "/", nil) + req2.Header.Set("X-Real-IP", "192.168.1.5") + rec2 := httptest.NewRecorder() + c2 := e.NewContext(req2, rec2) + + err = handler(c2) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec2.Code) + }) + + t.Run("returns JSON error response with timeout information", func(t *testing.T) { + duration := 2 * time.Second + middleware := WithTimeout(duration) + handler := middleware(func(c echo.Context) error { + return c.String(http.StatusOK, "success") + }) + + req1 := httptest.NewRequest(http.MethodGet, "/", nil) + req1.Header.Set("X-Real-IP", "192.168.1.6") + rec1 := httptest.NewRecorder() + c1 := e.NewContext(req1, rec1) + + _ = handler(c1) + + req2 := httptest.NewRequest(http.MethodGet, "/", nil) + req2.Header.Set("X-Real-IP", "192.168.1.6") + rec2 := httptest.NewRecorder() + c2 := e.NewContext(req2, rec2) + + err := handler(c2) + assert.NoError(t, err) + assert.Equal(t, http.StatusTooManyRequests, rec2.Code) + assert.Contains(t, rec2.Body.String(), "rate limit exceeded") + assert.Contains(t, rec2.Body.String(), "timeout") + }) + + t.Run("uses RealIP from context", func(t *testing.T) { + middleware := WithTimeout(500 * time.Millisecond) + handler := middleware(func(c echo.Context) error { + return c.String(http.StatusOK, "success") + }) + + req1 := httptest.NewRequest(http.MethodGet, "/", nil) + req1.Header.Set("X-Forwarded-For", "10.0.0.1") + rec1 := httptest.NewRecorder() + c1 := e.NewContext(req1, rec1) + + err := handler(c1) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec1.Code) + + req2 := httptest.NewRequest(http.MethodGet, "/", nil) + req2.Header.Set("X-Forwarded-For", "10.0.0.1") + rec2 := httptest.NewRecorder() + c2 := e.NewContext(req2, rec2) + + err = handler(c2) + assert.NoError(t, err) + assert.Equal(t, http.StatusTooManyRequests, rec2.Code) + }) + + t.Run("allows multiple requests within sequence correctly", func(t *testing.T) { + duration := 100 * time.Millisecond + middleware := WithTimeout(duration) + handler := middleware(func(c echo.Context) error { + return c.String(http.StatusOK, "success") + }) + + ip := "192.168.1.7" + + req1 := httptest.NewRequest(http.MethodGet, "/", nil) + req1.Header.Set("X-Real-IP", ip) + rec1 := httptest.NewRecorder() + c1 := e.NewContext(req1, rec1) + err := handler(c1) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec1.Code) + + req2 := httptest.NewRequest(http.MethodGet, "/", nil) + req2.Header.Set("X-Real-IP", ip) + rec2 := httptest.NewRecorder() + c2 := e.NewContext(req2, rec2) + err = handler(c2) + assert.NoError(t, err) + assert.Equal(t, http.StatusTooManyRequests, rec2.Code) + + time.Sleep(duration + 10*time.Millisecond) + + req3 := httptest.NewRequest(http.MethodGet, "/", nil) + req3.Header.Set("X-Real-IP", ip) + rec3 := httptest.NewRecorder() + c3 := e.NewContext(req3, rec3) + err = handler(c3) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec3.Code) + }) + + t.Run("handles concurrent requests from different IPs", func(t *testing.T) { + middleware := WithTimeout(1 * time.Second) + handler := middleware(func(c echo.Context) error { + return c.String(http.StatusOK, "success") + }) + + successCount := 0 + done := make(chan bool, 10) + + for i := 0; i < 10; i++ { + go func(index int) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-Real-IP", "192.168.1."+string(rune(100+index))) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + _ = handler(c) + if rec.Code == http.StatusOK { + done <- true + } else { + done <- false + } + }(i) + } + + for i := 0; i < 10; i++ { + if <-done { + successCount++ + } + } + + assert.Equal(t, 10, successCount) + }) + + t.Run("timeout value in response is accurate", func(t *testing.T) { + duration := 5 * time.Second + middleware := WithTimeout(duration) + handler := middleware(func(c echo.Context) error { + return c.String(http.StatusOK, "success") + }) + + req1 := httptest.NewRequest(http.MethodGet, "/", nil) + req1.Header.Set("X-Real-IP", "192.168.1.8") + rec1 := httptest.NewRecorder() + c1 := e.NewContext(req1, rec1) + _ = handler(c1) + + req2 := httptest.NewRequest(http.MethodGet, "/", nil) + req2.Header.Set("X-Real-IP", "192.168.1.8") + rec2 := httptest.NewRecorder() + c2 := e.NewContext(req2, rec2) + _ = handler(c2) + + body := rec2.Body.String() + assert.Contains(t, body, "timeout") + assert.Contains(t, body, "5s") + }) +} diff --git a/pkg/requestid/requestid_test.go b/pkg/requestid/requestid_test.go new file mode 100644 index 0000000..77b7b44 --- /dev/null +++ b/pkg/requestid/requestid_test.go @@ -0,0 +1,126 @@ +package requestid + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/google/uuid" + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" +) + +func TestNew(t *testing.T) { + e := echo.New() + + t.Run("generates and sets request ID when not present", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + called := false + handler := New(func(c echo.Context) error { + called = true + rid := c.Request().Header.Get(headerRequestID) + assert.NotEmpty(t, rid) + _, err := uuid.Parse(rid) + assert.NoError(t, err) + return nil + }) + + err := handler(c) + assert.NoError(t, err) + assert.True(t, called) + }) + + t.Run("preserves existing request ID from response header", func(t *testing.T) { + existingID := "existing-request-id" + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + rec.Header().Set(headerRequestID, existingID) + c := e.NewContext(req, rec) + + called := false + handler := New(func(c echo.Context) error { + called = true + return nil + }) + + err := handler(c) + assert.NoError(t, err) + assert.True(t, called) + }) + + t.Run("calls next handler", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + nextCalled := false + handler := New(func(c echo.Context) error { + nextCalled = true + return nil + }) + + err := handler(c) + assert.NoError(t, err) + assert.True(t, nextCalled) + }) + + t.Run("propagates error from next handler", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + expectedErr := echo.NewHTTPError(http.StatusInternalServerError, "test error") + handler := New(func(c echo.Context) error { + return expectedErr + }) + + err := handler(c) + assert.Equal(t, expectedErr, err) + }) +} + +func TestGet(t *testing.T) { + e := echo.New() + + t.Run("returns request ID from header", func(t *testing.T) { + expectedID := "test-request-id-123" + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(headerRequestID, expectedID) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + rid := Get(c) + assert.Equal(t, expectedID, rid) + }) + + t.Run("returns empty string when request ID not set", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + rid := Get(c) + assert.Equal(t, "", rid) + }) + + t.Run("integration: Get returns ID set by New middleware", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + var capturedID string + handler := New(func(c echo.Context) error { + capturedID = Get(c) + return nil + }) + + err := handler(c) + assert.NoError(t, err) + assert.NotEmpty(t, capturedID) + + _, err = uuid.Parse(capturedID) + assert.NoError(t, err) + }) +} diff --git a/pkg/requestlog/requestlog_test.go b/pkg/requestlog/requestlog_test.go new file mode 100644 index 0000000..34726e4 --- /dev/null +++ b/pkg/requestlog/requestlog_test.go @@ -0,0 +1,253 @@ +package requestlog + +import ( + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" + "github.com/voidcontests/api/internal/app/handler" +) + +func TestCompleted(t *testing.T) { + e := echo.New() + + t.Run("calls next handler and logs request", func(t *testing.T) { + nextCalled := false + middleware := Completed(func(c echo.Context) error { + nextCalled = true + return c.String(http.StatusOK, "success") + }) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Header.Set("X-Request-ID", "test-request-id") + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := middleware(c) + assert.NoError(t, err) + assert.True(t, nextCalled) + assert.Equal(t, http.StatusOK, rec.Code) + }) + + t.Run("skips OPTIONS requests", func(t *testing.T) { + nextCalled := false + middleware := Completed(func(c echo.Context) error { + nextCalled = true + return c.NoContent(http.StatusNoContent) + }) + + req := httptest.NewRequest(http.MethodOptions, "/test", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := middleware(c) + assert.NoError(t, err) + assert.True(t, nextCalled) + }) + + t.Run("skips logging for healthcheck endpoint with 200 status", func(t *testing.T) { + middleware := Completed(func(c echo.Context) error { + return c.String(http.StatusOK, "healthy") + }) + + req := httptest.NewRequest(http.MethodGet, "/api/healthcheck", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + c.SetPath("/api/healthcheck") + + err := middleware(c) + assert.NoError(t, err) + }) + + t.Run("logs healthcheck endpoint with non-200 status", func(t *testing.T) { + middleware := Completed(func(c echo.Context) error { + return c.String(http.StatusInternalServerError, "unhealthy") + }) + + req := httptest.NewRequest(http.MethodGet, "/api/healthcheck", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + c.SetPath("/api/healthcheck") + + err := middleware(c) + assert.NoError(t, err) + }) + + t.Run("propagates error from next handler", func(t *testing.T) { + expectedErr := errors.New("test error") + middleware := Completed(func(c echo.Context) error { + return expectedErr + }) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := middleware(c) + assert.Equal(t, expectedErr, err) + }) + + t.Run("handles APIError and extracts status code", func(t *testing.T) { + apiErr := &handler.APIError{ + Status: http.StatusBadRequest, + Message: "bad request", + } + + middleware := Completed(func(c echo.Context) error { + return apiErr + }) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Header.Set("X-Request-ID", "test-id") + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := middleware(c) + assert.Equal(t, apiErr, err) + }) + + t.Run("handles non-APIError and uses status 500", func(t *testing.T) { + genericErr := errors.New("generic error") + + middleware := Completed(func(c echo.Context) error { + return genericErr + }) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Header.Set("X-Request-ID", "test-id") + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := middleware(c) + assert.Equal(t, genericErr, err) + }) + + t.Run("logs correct request information", func(t *testing.T) { + middleware := Completed(func(c echo.Context) error { + return c.String(http.StatusCreated, "created") + }) + + req := httptest.NewRequest(http.MethodPost, "/api/users", nil) + req.Header.Set("X-Request-ID", "req-123") + req.Header.Set("User-Agent", "TestAgent/1.0") + req.Host = "example.com" + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + c.SetPath("/api/users") + + err := middleware(c) + assert.NoError(t, err) + assert.Equal(t, http.StatusCreated, rec.Code) + }) + + t.Run("measures request duration", func(t *testing.T) { + middleware := Completed(func(c echo.Context) error { + return c.String(http.StatusOK, "success") + }) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Header.Set("X-Request-ID", "test-id") + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := middleware(c) + assert.NoError(t, err) + }) + + t.Run("handles different HTTP methods", func(t *testing.T) { + methods := []string{ + http.MethodGet, + http.MethodPost, + http.MethodPut, + http.MethodPatch, + http.MethodDelete, + } + + for _, method := range methods { + middleware := Completed(func(c echo.Context) error { + return c.String(http.StatusOK, "success") + }) + + req := httptest.NewRequest(method, "/test", nil) + req.Header.Set("X-Request-ID", "test-id") + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := middleware(c) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) + } + }) + + t.Run("retrieves request ID from context", func(t *testing.T) { + expectedID := "custom-request-id-456" + middleware := Completed(func(c echo.Context) error { + return c.String(http.StatusOK, "success") + }) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Header.Set("X-Request-ID", expectedID) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := middleware(c) + assert.NoError(t, err) + }) + + t.Run("logs RealIP from context", func(t *testing.T) { + middleware := Completed(func(c echo.Context) error { + return c.String(http.StatusOK, "success") + }) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Header.Set("X-Real-IP", "203.0.113.42") + req.Header.Set("X-Request-ID", "test-id") + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := middleware(c) + assert.NoError(t, err) + }) + + t.Run("handles empty user agent", func(t *testing.T) { + middleware := Completed(func(c echo.Context) error { + return c.String(http.StatusOK, "success") + }) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Header.Set("X-Request-ID", "test-id") + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := middleware(c) + assert.NoError(t, err) + }) + + t.Run("integration: full request lifecycle", func(t *testing.T) { + handlerCalled := false + middleware := Completed(func(c echo.Context) error { + handlerCalled = true + return c.JSON(http.StatusOK, map[string]string{ + "message": "success", + }) + }) + + req := httptest.NewRequest(http.MethodPost, "/api/data", nil) + req.Header.Set("X-Request-ID", "integration-test-id") + req.Header.Set("User-Agent", "IntegrationTest/1.0") + req.Header.Set("X-Real-IP", "192.168.1.100") + req.Host = "api.example.com" + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + c.SetPath("/api/data") + + err := middleware(c) + assert.NoError(t, err) + assert.True(t, handlerCalled) + assert.Equal(t, http.StatusOK, rec.Code) + assert.Contains(t, rec.Body.String(), "success") + }) +} diff --git a/pkg/scheduler/scheduler.go b/pkg/scheduler/scheduler.go new file mode 100644 index 0000000..5ef13b2 --- /dev/null +++ b/pkg/scheduler/scheduler.go @@ -0,0 +1,61 @@ +package scheduler + +import ( + "context" + "log/slog" + "time" + + "github.com/voidcontests/api/internal/lib/logger/sl" +) + +type Task func(ctx context.Context) error + +type Scheduler struct { + interval time.Duration + task Task + stop chan struct{} + done chan struct{} +} + +func New(interval time.Duration, task Task) *Scheduler { + return &Scheduler{ + interval: interval, + task: task, + stop: make(chan struct{}), + done: make(chan struct{}), + } +} + +func (s *Scheduler) Start(ctx context.Context) { + defer close(s.done) + + ticker := time.NewTicker(s.interval) + defer ticker.Stop() + + slog.Info("scheduler: started", slog.String("interval", s.interval.String())) + + if err := s.task(ctx); err != nil { + slog.Error("scheduler: task execution failed", sl.Err(err)) + } + + for { + select { + case <-ticker.C: + if err := s.task(ctx); err != nil { + slog.Error("scheduler: task execution failed", sl.Err(err)) + } + case <-s.stop: + slog.Info("scheduler: stopping...") + return + case <-ctx.Done(): + slog.Info("scheduler: context cancelled, stopping...") + return + } + } +} + +func (s *Scheduler) Stop() { + close(s.stop) + <-s.done + slog.Info("scheduler: stopped") +} diff --git a/pkg/scheduler/scheduler_test.go b/pkg/scheduler/scheduler_test.go new file mode 100644 index 0000000..1fa9ef2 --- /dev/null +++ b/pkg/scheduler/scheduler_test.go @@ -0,0 +1,298 @@ +package scheduler + +import ( + "context" + "errors" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestNew(t *testing.T) { + t.Run("creates scheduler with correct interval and task", func(t *testing.T) { + interval := 100 * time.Millisecond + task := func(ctx context.Context) error { + return nil + } + + s := New(interval, task) + + assert.NotNil(t, s) + assert.Equal(t, interval, s.interval) + assert.NotNil(t, s.task) + assert.NotNil(t, s.stop) + assert.NotNil(t, s.done) + }) +} + +func TestScheduler_Start(t *testing.T) { + t.Run("executes task immediately on start", func(t *testing.T) { + executed := false + var mu sync.Mutex + + task := func(ctx context.Context) error { + mu.Lock() + executed = true + mu.Unlock() + return nil + } + + s := New(1*time.Hour, task) + ctx := context.Background() + + go s.Start(ctx) + time.Sleep(50 * time.Millisecond) + s.Stop() + + mu.Lock() + assert.True(t, executed) + mu.Unlock() + }) + + t.Run("executes task periodically", func(t *testing.T) { + var count int + var mu sync.Mutex + + task := func(ctx context.Context) error { + mu.Lock() + count++ + mu.Unlock() + return nil + } + + interval := 50 * time.Millisecond + s := New(interval, task) + ctx := context.Background() + + go s.Start(ctx) + time.Sleep(160 * time.Millisecond) + s.Stop() + + mu.Lock() + assert.GreaterOrEqual(t, count, 3) + mu.Unlock() + }) + + t.Run("stops when Stop is called", func(t *testing.T) { + var count int + var mu sync.Mutex + + task := func(ctx context.Context) error { + mu.Lock() + count++ + mu.Unlock() + return nil + } + + interval := 30 * time.Millisecond + s := New(interval, task) + ctx := context.Background() + + go s.Start(ctx) + time.Sleep(100 * time.Millisecond) + s.Stop() + + mu.Lock() + countAfterStop := count + mu.Unlock() + + time.Sleep(100 * time.Millisecond) + + mu.Lock() + assert.Equal(t, countAfterStop, count, "task should not execute after Stop") + mu.Unlock() + }) + + t.Run("stops when context is cancelled", func(t *testing.T) { + var count int + var mu sync.Mutex + + task := func(ctx context.Context) error { + mu.Lock() + count++ + mu.Unlock() + return nil + } + + interval := 30 * time.Millisecond + s := New(interval, task) + ctx, cancel := context.WithCancel(context.Background()) + + go s.Start(ctx) + time.Sleep(100 * time.Millisecond) + cancel() + + time.Sleep(50 * time.Millisecond) + + mu.Lock() + countAfterCancel := count + mu.Unlock() + + time.Sleep(100 * time.Millisecond) + + mu.Lock() + assert.Equal(t, countAfterCancel, count, "task should not execute after context cancellation") + mu.Unlock() + }) + + t.Run("continues execution when task returns error", func(t *testing.T) { + var count int + var mu sync.Mutex + + task := func(ctx context.Context) error { + mu.Lock() + count++ + mu.Unlock() + return errors.New("test error") + } + + interval := 50 * time.Millisecond + s := New(interval, task) + ctx := context.Background() + + go s.Start(ctx) + time.Sleep(160 * time.Millisecond) + s.Stop() + + mu.Lock() + assert.GreaterOrEqual(t, count, 3, "scheduler should continue despite task errors") + mu.Unlock() + }) + + t.Run("task receives correct context", func(t *testing.T) { + type contextKey string + key := contextKey("test-key") + expectedValue := "test-value" + + var receivedValue string + var mu sync.Mutex + + task := func(ctx context.Context) error { + mu.Lock() + if val := ctx.Value(key); val != nil { + receivedValue = val.(string) + } + mu.Unlock() + return nil + } + + interval := 1 * time.Hour + s := New(interval, task) + ctx := context.WithValue(context.Background(), key, expectedValue) + + go s.Start(ctx) + time.Sleep(50 * time.Millisecond) + s.Stop() + + mu.Lock() + assert.Equal(t, expectedValue, receivedValue) + mu.Unlock() + }) +} + +func TestScheduler_Stop(t *testing.T) { + t.Run("waits for scheduler to finish", func(t *testing.T) { + taskStarted := make(chan struct{}) + taskCanFinish := make(chan struct{}) + + task := func(ctx context.Context) error { + close(taskStarted) + <-taskCanFinish + return nil + } + + s := New(1*time.Hour, task) + ctx := context.Background() + + go s.Start(ctx) + + <-taskStarted + + stopFinished := make(chan struct{}) + go func() { + s.Stop() + close(stopFinished) + }() + + select { + case <-stopFinished: + t.Fatal("Stop should wait for task to finish") + case <-time.After(50 * time.Millisecond): + } + + close(taskCanFinish) + + select { + case <-stopFinished: + case <-time.After(100 * time.Millisecond): + t.Fatal("Stop did not complete in time") + } + }) + + t.Run("Stop waits for graceful shutdown", func(t *testing.T) { + task := func(ctx context.Context) error { + return nil + } + + s := New(1*time.Hour, task) + ctx := context.Background() + + go s.Start(ctx) + time.Sleep(50 * time.Millisecond) + + done := make(chan struct{}) + go func() { + s.Stop() + close(done) + }() + + select { + case <-done: + case <-time.After(1 * time.Second): + t.Fatal("Stop deadlocked or hung") + } + }) +} + +func TestScheduler_Integration(t *testing.T) { + t.Run("realistic usage scenario", func(t *testing.T) { + var executionTimes []time.Time + var mu sync.Mutex + + task := func(ctx context.Context) error { + mu.Lock() + executionTimes = append(executionTimes, time.Now()) + mu.Unlock() + return nil + } + + interval := 40 * time.Millisecond + s := New(interval, task) + ctx := context.Background() + + start := time.Now() + go s.Start(ctx) + time.Sleep(150 * time.Millisecond) + s.Stop() + elapsed := time.Since(start) + + mu.Lock() + count := len(executionTimes) + mu.Unlock() + + assert.GreaterOrEqual(t, count, 3) + assert.Less(t, elapsed, 200*time.Millisecond) + + mu.Lock() + if len(executionTimes) >= 2 { + for i := 1; i < len(executionTimes); i++ { + gap := executionTimes[i].Sub(executionTimes[i-1]) + assert.Greater(t, gap, 30*time.Millisecond) + assert.Less(t, gap, 60*time.Millisecond) + } + } + mu.Unlock() + }) +} diff --git a/pkg/ton/address.go b/pkg/ton/address.go new file mode 100644 index 0000000..527c87e --- /dev/null +++ b/pkg/ton/address.go @@ -0,0 +1,7 @@ +package ton + +import "github.com/xssnick/tonutils-go/address" + +func (c *Client) GetAddress(addr *address.Address) string { + return addr.Testnet(c.testnet).String() +} diff --git a/pkg/ton/balance.go b/pkg/ton/balance.go new file mode 100644 index 0000000..21db088 --- /dev/null +++ b/pkg/ton/balance.go @@ -0,0 +1,70 @@ +package ton + +import ( + "context" + "fmt" + "log/slog" + "time" + + "github.com/xssnick/tonutils-go/address" +) + +const balanceCacheTTL = 5 * time.Minute + +type balanceCacheEntry struct { + balance uint64 + expiresAt time.Time +} + +func (c *Client) GetBalance(ctx context.Context, address *address.Address) (uint64, error) { + block, err := c.api.CurrentMasterchainInfo(ctx) + if err != nil { + return 0, fmt.Errorf("failed to get masterchain info: %w", err) + } + + account, err := c.api.GetAccount(ctx, block, address) + if err != nil { + return 0, fmt.Errorf("failed to get account: %w", err) + } + + if !account.IsActive { + return 0, nil + } + + return account.State.Balance.Nano().Uint64(), nil +} + +func (c *Client) GetBalanceCached(ctx context.Context, address *address.Address) (uint64, error) { + cacheKey := c.GetAddress(address) + if cached, ok := c.balanceCache.Load(cacheKey); ok { + entry := cached.(*balanceCacheEntry) + if time.Now().Before(entry.expiresAt) { + slog.Info("cache: returned balance from cache") + return entry.balance, nil + } + c.balanceCache.Delete(cacheKey) + } + + balance, err := c.GetBalance(ctx, address) + if err != nil { + return 0, err + } + + c.balanceCache.Store(cacheKey, &balanceCacheEntry{ + balance: balance, + expiresAt: time.Now().Add(balanceCacheTTL), + }) + + return balance, nil +} + +func (c *Client) ClearBalanceCache() { + c.balanceCache.Range(func(key, value interface{}) bool { + c.balanceCache.Delete(key) + return true + }) +} + +func (c *Client) InvalidateBalance(address *address.Address) { + c.balanceCache.Delete(address.String()) +} diff --git a/pkg/ton/client.go b/pkg/ton/client.go new file mode 100644 index 0000000..1d7d0e3 --- /dev/null +++ b/pkg/ton/client.go @@ -0,0 +1,139 @@ +package ton + +import ( + "context" + "fmt" + "strings" + "sync" + + "github.com/tonkeeper/tongo/tonconnect" + "github.com/voidcontests/api/internal/config" + "github.com/xssnick/tonutils-go/address" + "github.com/xssnick/tonutils-go/liteclient" + "github.com/xssnick/tonutils-go/tlb" + tonutils "github.com/xssnick/tonutils-go/ton" + "github.com/xssnick/tonutils-go/ton/wallet" +) + +const ( + MainnetID = "-239" + TestnetID = "-3" +) + +type Client struct { + api tonutils.APIClientWrapped + testnet bool + TonConnect *tonconnect.Server + balanceCache sync.Map // map[string]*balanceCacheEntry, key is address string +} + +func NewClient(ctx context.Context, c *config.Ton) (*Client, error) { + client := liteclient.NewConnectionPool() + err := client.AddConnectionsFromConfigUrl(ctx, c.ConfigURL) + if err != nil { + return nil, err + } + + unsafeAPI := tonutils.NewAPIClient(client, tonutils.ProofCheckPolicyUnsafe) + block, err := unsafeAPI.GetMasterchainInfo(ctx) + if err != nil { + return nil, err + } + + api := tonutils.NewAPIClient(client, tonutils.ProofCheckPolicySecure).WithRetry() + api.SetTrustedBlock(block) + + tc := &Client{ + api: api, + testnet: c.IsTestnet, + } + + payloadLifetime := int64(c.Proof.PayloadLifetime.Seconds()) + proofLifetime := int64(c.Proof.ProofLifetime.Seconds()) + + tcserver, err := tonconnect.NewTonConnect( + tc, + c.Proof.PayloadSignatureKey, + tonconnect.WithLifeTimePayload(payloadLifetime), + tonconnect.WithLifeTimeProof(proofLifetime), + ) + if err != nil { + return nil, fmt.Errorf("tonconnect: can't initialize: %w", err) + } + + tc.TonConnect = tcserver + + return tc, nil +} + +func (c *Client) CreateWallet() (*Wallet, error) { + return c.WalletWithSeed(strings.Join(wallet.NewSeed(), " ")) +} + +func (c *Client) WalletWithSeed(mnemonic string) (*Wallet, error) { + words := strings.Split(mnemonic, " ") + w, err := wallet.FromSeedWithOptions(c.api, words, wallet.V4R2) + if err != nil { + return nil, fmt.Errorf("failed to create wallet from seed: %w", err) + } + + return &Wallet{ + address: w.WalletAddress(), + Mnemonic: words, + Instance: w, + testnet: c.testnet, + }, nil +} + +func FromNano(nano uint64) string { + return tlb.FromNanoTONU(nano).String() +} + +func (c *Client) LookupTx(ctx context.Context, from *address.Address, to *address.Address, amount tlb.Coins) (string, bool) { + block, err := c.api.CurrentMasterchainInfo(ctx) + if err != nil { + return "", false + } + + account, err := c.api.GetAccount(ctx, block, to) + if err != nil { + return "", false + } + + if !account.IsActive { + return "", false + } + + txs, err := c.api.ListTransactions(ctx, to, 100, account.LastTxLT, account.LastTxHash) + if err != nil { + return "", false + } + + for _, tx := range txs { + if tx.IO.In == nil || tx.IO.In.MsgType != tlb.MsgTypeInternal { + continue + } + + inmsg := tx.IO.In.AsInternal() + if inmsg == nil { + continue + } + + if inmsg.SrcAddr.Equals(from) { + // checks if in tx transferred at least `amount` + if inmsg.Amount.Nano().Cmp(amount.Nano()) >= 0 { + return fmt.Sprintf("%x", tx.Hash), true + } + } + } + + return "", false +} + +func (c *Client) IsTestnet() bool { + return c.testnet +} + +func (c *Client) API() tonutils.APIClientWrapped { + return c.api +} diff --git a/pkg/ton/executor.go b/pkg/ton/executor.go new file mode 100644 index 0000000..f9adb86 --- /dev/null +++ b/pkg/ton/executor.go @@ -0,0 +1,33 @@ +package ton + +import ( + "context" + + "github.com/tonkeeper/tongo/tlb" + "github.com/tonkeeper/tongo/ton" + "github.com/xssnick/tonutils-go/address" +) + +// RunSmcMethodByID is an implementation of `abi.Executor` interface for `tonconnect.NewTonConnect` function +func (c *Client) RunSmcMethodByID(ctx context.Context, accountID ton.AccountID, methodID int, params tlb.VmStack) (uint32, tlb.VmStack, error) { + rawAddr := accountID.ToRaw() + addr, err := address.ParseAddr(rawAddr) + if err != nil { + return 0, tlb.VmStack{}, err + } + + block, err := c.api.CurrentMasterchainInfo(ctx) + if err != nil { + return 0, tlb.VmStack{}, err + } + + var tuParams []interface{} + + methodName := "" + _, err = c.api.RunGetMethod(ctx, block, addr, methodName, tuParams...) + if err != nil { + return 0, tlb.VmStack{}, err + } + + return 0, tlb.VmStack{}, nil +} diff --git a/pkg/ton/ton_test.go b/pkg/ton/ton_test.go new file mode 100644 index 0000000..9b7a79a --- /dev/null +++ b/pkg/ton/ton_test.go @@ -0,0 +1,244 @@ +package ton + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/xssnick/tonutils-go/address" +) + +func TestFromNano(t *testing.T) { + tests := []struct { + name string + nano uint64 + expected string + }{ + { + name: "zero nanotons", + nano: 0, + expected: "0", + }, + { + name: "one TON", + nano: 1_000_000_000, + expected: "1", + }, + { + name: "fractional TON", + nano: 500_000_000, + expected: "0.5", + }, + { + name: "small amount", + nano: 1, + expected: "0.000000001", + }, + { + name: "large amount", + nano: 1_234_567_890_000, + expected: "1234.56789", + }, + { + name: "10.5 TON", + nano: 10_500_000_000, + expected: "10.5", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := FromNano(tt.nano) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestClient_IsTestnet(t *testing.T) { + t.Run("returns true for testnet client", func(t *testing.T) { + client := &Client{testnet: true} + assert.True(t, client.IsTestnet()) + }) + + t.Run("returns false for mainnet client", func(t *testing.T) { + client := &Client{testnet: false} + assert.False(t, client.IsTestnet()) + }) +} + +func TestClient_GetAddress(t *testing.T) { + t.Run("formats address for testnet", func(t *testing.T) { + addr, err := address.ParseRawAddr("0:8156fbabb2c8c8119dab794ca096f57c3af1775549469f0d1b4e766d8e613c36") + assert.NoError(t, err) + client := &Client{testnet: true} + + result := client.GetAddress(addr) + assert.NotEmpty(t, result) + assert.True(t, result[0] == 'k' || result[0] == '0') + }) + + t.Run("formats address for mainnet", func(t *testing.T) { + addr, err := address.ParseRawAddr("0:8156fbabb2c8c8119dab794ca096f57c3af1775549469f0d1b4e766d8e613c36") + assert.NoError(t, err) + client := &Client{testnet: false} + + result := client.GetAddress(addr) + assert.NotEmpty(t, result) + assert.True(t, result[0] == 'E' || result[0] == '0') + }) +} + +func TestWallet_Address(t *testing.T) { + t.Run("returns testnet address when testnet is true", func(t *testing.T) { + addr, err := address.ParseRawAddr("0:8156fbabb2c8c8119dab794ca096f57c3af1775549469f0d1b4e766d8e613c36") + assert.NoError(t, err) + wallet := &Wallet{ + address: addr, + testnet: true, + } + + result := wallet.Address() + assert.NotNil(t, result) + }) + + t.Run("returns mainnet address when testnet is false", func(t *testing.T) { + addr, err := address.ParseRawAddr("0:8156fbabb2c8c8119dab794ca096f57c3af1775549469f0d1b4e766d8e613c36") + assert.NoError(t, err) + wallet := &Wallet{ + address: addr, + testnet: false, + } + + result := wallet.Address() + assert.NotNil(t, result) + }) +} + +func TestClient_ClearBalanceCache(t *testing.T) { + t.Run("clears all cached balances", func(t *testing.T) { + client := &Client{} + + addr1, err := address.ParseRawAddr("0:8156fbabb2c8c8119dab794ca096f57c3af1775549469f0d1b4e766d8e613c36") + assert.NoError(t, err) + addr2, err := address.ParseRawAddr("0:1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef") + assert.NoError(t, err) + + key1 := client.GetAddress(addr1) + key2 := client.GetAddress(addr2) + + client.balanceCache.Store(key1, &balanceCacheEntry{balance: 1000}) + client.balanceCache.Store(key2, &balanceCacheEntry{balance: 2000}) + + _, ok1 := client.balanceCache.Load(key1) + _, ok2 := client.balanceCache.Load(key2) + assert.True(t, ok1) + assert.True(t, ok2) + + client.ClearBalanceCache() + + _, ok1 = client.balanceCache.Load(key1) + _, ok2 = client.balanceCache.Load(key2) + assert.False(t, ok1) + assert.False(t, ok2) + }) + + t.Run("works with empty cache", func(t *testing.T) { + client := &Client{} + + assert.NotPanics(t, func() { + client.ClearBalanceCache() + }) + }) +} + +func TestClient_InvalidateBalance(t *testing.T) { + t.Run("invalidates specific address balance", func(t *testing.T) { + client := &Client{} + + addr1, err := address.ParseRawAddr("0:8156fbabb2c8c8119dab794ca096f57c3af1775549469f0d1b4e766d8e613c36") + assert.NoError(t, err) + addr2, err := address.ParseRawAddr("0:1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef") + assert.NoError(t, err) + + key1 := addr1.String() + key2 := addr2.String() + + client.balanceCache.Store(key1, &balanceCacheEntry{balance: 1000}) + client.balanceCache.Store(key2, &balanceCacheEntry{balance: 2000}) + + client.InvalidateBalance(addr1) + + _, ok1 := client.balanceCache.Load(key1) + _, ok2 := client.balanceCache.Load(key2) + assert.False(t, ok1) + assert.True(t, ok2) + }) + + t.Run("works when address not in cache", func(t *testing.T) { + client := &Client{} + addr, err := address.ParseRawAddr("0:8156fbabb2c8c8119dab794ca096f57c3af1775549469f0d1b4e766d8e613c36") + assert.NoError(t, err) + + assert.NotPanics(t, func() { + client.InvalidateBalance(addr) + }) + }) +} + +func TestConstants(t *testing.T) { + t.Run("mainnet ID is correct", func(t *testing.T) { + assert.Equal(t, "-239", MainnetID) + }) + + t.Run("testnet ID is correct", func(t *testing.T) { + assert.Equal(t, "-3", TestnetID) + }) +} + +/* +Integration Tests (require actual blockchain connection): + +The following tests would require a connection to TON blockchain (testnet or mainnet) +and should be run as integration tests with proper setup: + +1. TestNewClient - requires valid config and network connection +2. TestClient_CreateWallet - requires API client +3. TestClient_WalletWithSeed - requires API client and valid mnemonic +4. TestClient_GetBalance - requires API client and blockchain query +5. TestClient_GetBalanceCached - requires API client, tests caching behavior +6. TestClient_LookupTx - requires API client and actual transactions +7. TestWallet_TransferTo - requires API client, funded wallet, and actual transfer +8. TestClient_RunSmcMethodByID - requires API client and smart contract + +Example integration test structure: + +func TestIntegration_CreateWallet(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test") + } + + ctx := context.Background() + cfg := &config.Ton{ + ConfigURL: "https: IsTestnet: true, + Proof: config.TonProof{ + PayloadLifetime: time.Minute * 5, + ProofLifetime: time.Minute * 5, + PayloadSignatureKey: "test-key", + }, + } + + client, err := NewClient(ctx, cfg) + require.NoError(t, err) + require.NotNil(t, client) + + wallet, err := client.CreateWallet() + require.NoError(t, err) + require.NotNil(t, wallet) + require.NotNil(t, wallet.Address()) + require.Len(t, wallet.Mnemonic, 24) +} + +To run integration tests: + go test -v ./pkg/ton/... -run Integration +To skip integration tests: + go test -v ./pkg/ton/... -short +*/ diff --git a/pkg/ton/wallet.go b/pkg/ton/wallet.go new file mode 100644 index 0000000..acd9097 --- /dev/null +++ b/pkg/ton/wallet.go @@ -0,0 +1,52 @@ +package ton + +import ( + "context" + "fmt" + "strings" + + "github.com/xssnick/tonutils-go/address" + "github.com/xssnick/tonutils-go/tlb" + "github.com/xssnick/tonutils-go/ton/wallet" + "github.com/xssnick/tonutils-go/tvm/cell" +) + +type Wallet struct { + address *address.Address + Mnemonic []string + Instance *wallet.Wallet + testnet bool +} + +func (w *Wallet) TransferTo(ctx context.Context, recipient *address.Address, amount tlb.Coins, comments ...string) (tx string, err error) { + var body *cell.Cell + + if len(comments) > 0 { + comment := strings.Join(comments, ", ") + body, err = wallet.CreateCommentCell(comment) + if err != nil { + return "", fmt.Errorf("failed to create comment cell: %w", err) + } + } + + transaction, _, err := w.Instance.SendWaitTransaction(ctx, &wallet.Message{ + Mode: wallet.PayGasSeparately + wallet.IgnoreErrors, + InternalMessage: &tlb.InternalMessage{ + IHRDisabled: true, + Bounce: false, + DstAddr: recipient, + Amount: amount, + Body: body, + }, + }) + + if err != nil { + return "", fmt.Errorf("failed to send transaction: %w", err) + } + + return fmt.Sprintf("%x", transaction.Hash), nil +} + +func (w *Wallet) Address() *address.Address { + return w.address.Testnet(w.testnet) +} diff --git a/pkg/validate/validate_test.go b/pkg/validate/validate_test.go new file mode 100644 index 0000000..5ba6d59 --- /dev/null +++ b/pkg/validate/validate_test.go @@ -0,0 +1,361 @@ +package validate + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" +) + +func TestBind(t *testing.T) { + e := echo.New() + + t.Run("successfully binds and validates valid JSON", func(t *testing.T) { + type TestStruct struct { + Name string `json:"name" required:"true"` + Email string `json:"email" required:"true"` + Age int `json:"age"` + } + + jsonBody := `{"name":"John","email":"john@example.com","age":30}` + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(jsonBody)) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + var result TestStruct + err := Bind(c, &result) + + assert.NoError(t, err) + assert.Equal(t, "John", result.Name) + assert.Equal(t, "john@example.com", result.Email) + assert.Equal(t, 30, result.Age) + }) + + t.Run("returns error when required field is missing", func(t *testing.T) { + type TestStruct struct { + Name string `json:"name" required:"true"` + Email string `json:"email" required:"true"` + } + + jsonBody := `{"name":"John"}` + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(jsonBody)) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + var result TestStruct + err := Bind(c, &result) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "email") + assert.Contains(t, err.Error(), "required") + }) + + t.Run("returns error when destination is not a pointer", func(t *testing.T) { + type TestStruct struct { + Name string `json:"name"` + } + + jsonBody := `{"name":"John"}` + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(jsonBody)) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + var result TestStruct + err := Bind(c, result) + assert.Error(t, err) + assert.Contains(t, err.Error(), "pointer") + }) + + t.Run("returns error on invalid JSON", func(t *testing.T) { + type TestStruct struct { + Name string `json:"name"` + } + + jsonBody := `{"name":invalid}` + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(jsonBody)) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + var result TestStruct + err := Bind(c, &result) + + assert.Error(t, err) + }) + + t.Run("sets Content-Type header to application/json", func(t *testing.T) { + type TestStruct struct { + Name string `json:"name"` + } + + jsonBody := `{"name":"John"}` + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(jsonBody)) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + var result TestStruct + _ = Bind(c, &result) + + assert.Equal(t, "application/json", c.Request().Header.Get("Content-Type")) + }) +} + +func TestStruct(t *testing.T) { + t.Run("validates struct with all required fields present", func(t *testing.T) { + type TestStruct struct { + Name string `json:"name" required:"true"` + Email string `json:"email" required:"true"` + Age int `json:"age" required:"true"` + } + + data := &TestStruct{ + Name: "John", + Email: "john@example.com", + Age: 30, + } + + err := Struct(data) + assert.NoError(t, err) + }) + + t.Run("returns error when string field is empty", func(t *testing.T) { + type TestStruct struct { + Name string `json:"name" required:"true"` + } + + data := &TestStruct{Name: ""} + err := Struct(data) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "name") + assert.Contains(t, err.Error(), "required") + }) + + t.Run("returns error when int field is zero or negative", func(t *testing.T) { + type TestStruct struct { + Age int `json:"age" required:"true"` + } + + data := &TestStruct{Age: 0} + err := Struct(data) + assert.Error(t, err) + + data = &TestStruct{Age: -1} + err = Struct(data) + assert.Error(t, err) + }) + + t.Run("validates positive int field", func(t *testing.T) { + type TestStruct struct { + Age int `json:"age" required:"true"` + } + + data := &TestStruct{Age: 30} + err := Struct(data) + assert.NoError(t, err) + }) + + t.Run("returns error when slice is empty", func(t *testing.T) { + type TestStruct struct { + Tags []string `json:"tags" required:"true"` + } + + data := &TestStruct{Tags: []string{}} + err := Struct(data) + assert.Error(t, err) + }) + + t.Run("validates non-empty slice", func(t *testing.T) { + type TestStruct struct { + Tags []string `json:"tags" required:"true"` + } + + data := &TestStruct{Tags: []string{"tag1", "tag2"}} + err := Struct(data) + assert.NoError(t, err) + }) + + t.Run("validates nested structs", func(t *testing.T) { + type Address struct { + Street string `json:"street" required:"true"` + City string `json:"city" required:"true"` + } + + type Person struct { + Name string `json:"name" required:"true"` + Address Address `json:"address"` + } + + data := &Person{ + Name: "John", + Address: Address{ + Street: "123 Main St", + City: "New York", + }, + } + + err := Struct(data) + assert.NoError(t, err) + }) + + t.Run("returns error when nested struct has missing required field", func(t *testing.T) { + type Address struct { + Street string `json:"street" required:"true"` + City string `json:"city" required:"true"` + } + + type Person struct { + Name string `json:"name" required:"true"` + Address Address `json:"address"` + } + + data := &Person{ + Name: "John", + Address: Address{ + Street: "123 Main St", + City: ""}, + } + + err := Struct(data) + assert.Error(t, err) + assert.Contains(t, err.Error(), "Address") + assert.Contains(t, err.Error(), "city") + }) + + t.Run("handles time.Time fields without validating them as structs", func(t *testing.T) { + type TestStruct struct { + Name string `json:"name" required:"true"` + CreatedAt time.Time `json:"created_at"` + } + + data := &TestStruct{ + Name: "John", + CreatedAt: time.Now(), + } + + err := Struct(data) + assert.NoError(t, err) + }) + + t.Run("returns error when destination is not a pointer", func(t *testing.T) { + type TestStruct struct { + Name string `json:"name"` + } + + data := TestStruct{Name: "John"} + err := Struct(data) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "pointer") + }) + + t.Run("returns error when destination is not a struct", func(t *testing.T) { + data := "not a struct" + err := Struct(&data) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "struct") + }) + + t.Run("validates pointer fields", func(t *testing.T) { + type TestStruct struct { + Name *string `json:"name" required:"true"` + } + + name := "John" + data := &TestStruct{Name: &name} + err := Struct(data) + assert.NoError(t, err) + }) + + t.Run("returns error when pointer field is nil", func(t *testing.T) { + type TestStruct struct { + Name *string `json:"name" required:"true"` + } + + data := &TestStruct{Name: nil} + err := Struct(data) + assert.Error(t, err) + }) + + t.Run("validates map fields", func(t *testing.T) { + type TestStruct struct { + Metadata map[string]string `json:"metadata" required:"true"` + } + + data := &TestStruct{Metadata: map[string]string{"key": "value"}} + err := Struct(data) + assert.NoError(t, err) + }) + + t.Run("returns error when map field is empty", func(t *testing.T) { + type TestStruct struct { + Metadata map[string]string `json:"metadata" required:"true"` + } + + data := &TestStruct{Metadata: map[string]string{}} + err := Struct(data) + assert.Error(t, err) + }) + + t.Run("validates float fields", func(t *testing.T) { + type TestStruct struct { + Price float64 `json:"price" required:"true"` + } + + data := &TestStruct{Price: 9.99} + err := Struct(data) + assert.NoError(t, err) + }) + + t.Run("returns error when float field is zero or negative", func(t *testing.T) { + type TestStruct struct { + Price float64 `json:"price" required:"true"` + } + + data := &TestStruct{Price: 0.0} + err := Struct(data) + assert.Error(t, err) + + data = &TestStruct{Price: -1.5} + err = Struct(data) + assert.Error(t, err) + }) + + t.Run("validates uint fields", func(t *testing.T) { + type TestStruct struct { + Count uint `json:"count" required:"true"` + } + + data := &TestStruct{Count: 10} + err := Struct(data) + assert.NoError(t, err) + }) + + t.Run("returns error when uint field is zero", func(t *testing.T) { + type TestStruct struct { + Count uint `json:"count" required:"true"` + } + + data := &TestStruct{Count: 0} + err := Struct(data) + assert.Error(t, err) + }) + + t.Run("ignores fields without required tag", func(t *testing.T) { + type TestStruct struct { + Name string `json:"name" required:"true"` + Optional string `json:"optional"` + } + + data := &TestStruct{ + Name: "John", + Optional: ""} + + err := Struct(data) + assert.NoError(t, err) + }) +}