diff --git a/.env-sample b/.env-sample index 1a8a81ab..09053add 100644 --- a/.env-sample +++ b/.env-sample @@ -23,3 +23,9 @@ VPN_DEPLOYER_API_SG=abcd EREBRUS_API_US_EAST=abcd EREBRUS_API_SG=abcd GOOGLE_AUDIENCE=abcd +EREBRUS_API_US_EAST=abcd +EREBRUS_API_SG=abcd +STRIPE_SECRET_KEY=sk_test_51JZtirA56jIbQVxXgbZcpr8OJ7kk3NLzGvyyure9Qrrt5fdxOUaZLOxxSh2JdNoBVv6LBgM595IfmAogqs9miIt200QTRNCFuB +STRIPE_SUCCESS_URL=http://localhost:3000/success +STRIPE_CANCEL_URL=http://localhost:3000/cancel +STRIPE_WEBHOOK_SECRET=whsec_4sG3SuGhIxlBrztcJshn8gOB8VXwPpb4 \ No newline at end of file diff --git a/api/v1/plan/cancel.go b/api/v1/plan/cancel.go new file mode 100644 index 00000000..44705800 --- /dev/null +++ b/api/v1/plan/cancel.go @@ -0,0 +1,42 @@ +package plan + +import ( + "net/http" + + "github.com/NetSepio/gateway/api/middleware/auth/paseto" + "github.com/NetSepio/gateway/config/dbconfig" + "github.com/NetSepio/gateway/models" + "github.com/NetSepio/gateway/util/pkg/logwrapper" + "github.com/TheLazarusNetwork/go-helpers/httpo" + "github.com/gin-gonic/gin" + "github.com/stripe/stripe-go/v76" + "github.com/stripe/stripe-go/v76/subscription" +) + +func CancelStripeSubscription(c *gin.Context) { + userId := c.GetString(paseto.CTX_USER_ID) + db := dbconfig.GetDb() + + var user models.User + if err := db.Where("user_id = ?", userId).First(&user).Error; err != nil { + httpo.NewErrorResponse(http.StatusInternalServerError, "User not found").SendD(c) + return + } + + if user.SubscriptionStatus == "basic" || user.StripeSubscriptionId == nil { + httpo.NewErrorResponse(http.StatusBadRequest, "No active subscription to cancel").SendD(c) + return + } + + // Proceed to cancel the subscription + _, err := subscription.Update(*user.StripeSubscriptionId, &stripe.SubscriptionParams{ + CancelAtPeriodEnd: stripe.Bool(true), + }) + if err != nil { + logwrapper.Errorf("Stripe subscription cancellation failed: %v", err) + httpo.NewErrorResponse(http.StatusInternalServerError, "Failed to cancel subscription").SendD(c) + return + } + + httpo.NewSuccessResponse(http.StatusOK, "Subscription cancelled successfully").SendD(c) +} diff --git a/api/v1/plan/create.go b/api/v1/plan/create.go new file mode 100644 index 00000000..671a938a --- /dev/null +++ b/api/v1/plan/create.go @@ -0,0 +1,86 @@ +package plan + +import ( + "fmt" + "net/http" + + "github.com/NetSepio/gateway/api/middleware/auth/paseto" + "github.com/NetSepio/gateway/config/dbconfig" + "github.com/NetSepio/gateway/config/envconfig" + "github.com/NetSepio/gateway/models" + "github.com/NetSepio/gateway/util/pkg/logwrapper" + "github.com/TheLazarusNetwork/go-helpers/httpo" + "github.com/gin-gonic/gin" + "github.com/stripe/stripe-go/v76" + "github.com/stripe/stripe-go/v76/checkout/session" + "github.com/stripe/stripe-go/v76/customer" +) + +func CreateStripeSession(c *gin.Context) { + userId := c.GetString(paseto.CTX_USER_ID) + var req struct { + PriceID string `json:"priceId" binding:"required"` + } + if err := c.BindJSON(&req); err != nil { + httpo.NewErrorResponse(http.StatusBadRequest, fmt.Sprintf("Invalid request: %s", err)).SendD(c) + return + } + db := dbconfig.GetDb() + var user models.User + if err := db.Where("user_id = ?", userId).First(&user).Error; err != nil { + logwrapper.Errorf("Failed to find user: %v", err) + httpo.NewErrorResponse(http.StatusInternalServerError, "internal server error").SendD(c) + return + } + + var customerID string + + // Check if stripe_customer_id is null + if user.StripeCustomerId == "" { + // Create a new Stripe customer + customerParams := &stripe.CustomerParams{} + stripeCustomer, err := customer.New(customerParams) + if err != nil { + logwrapper.Errorf("Stripe customer creation failed: %v", err) + httpo.NewErrorResponse(http.StatusInternalServerError, "internal server error").SendD(c) + return + } + customerID = stripeCustomer.ID + + // Update user with new stripe_customer_id + user.StripeCustomerId = stripeCustomer.ID + if err := db.Save(&user).Error; err != nil { + logwrapper.Errorf("Failed to update user: %v", err) + httpo.NewErrorResponse(http.StatusInternalServerError, "internal server error").SendD(c) + return + } + } else { + customerID = user.StripeCustomerId + } + + params := &stripe.CheckoutSessionParams{ + PaymentMethodTypes: stripe.StringSlice([]string{"card"}), + LineItems: []*stripe.CheckoutSessionLineItemParams{ + { + Price: stripe.String(req.PriceID), + Quantity: stripe.Int64(1), + }, + }, + Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)), + SuccessURL: stripe.String(envconfig.EnvVars.STRIPE_SUCCESS_URL), + CancelURL: stripe.String(envconfig.EnvVars.STRIPE_CANCEL_URL), + ClientReferenceID: stripe.String(userId), + Customer: stripe.String(customerID), + } + + s, err := session.New(params) + if err != nil { + logwrapper.Errorf("Stripe session creation failed: %v", err) + httpo.NewErrorResponse(http.StatusInternalServerError, "Failed to create Stripe session").SendD(c) + return + } + + fmt.Println("customer ", s.Customer) + + httpo.NewSuccessResponseP(http.StatusOK, "Session created successfully", gin.H{"session_url": s.URL}).SendD(c) +} diff --git a/api/v1/plan/plan.go b/api/v1/plan/plan.go new file mode 100644 index 00000000..e5eccfb5 --- /dev/null +++ b/api/v1/plan/plan.go @@ -0,0 +1,16 @@ +package plan + +import ( + "github.com/NetSepio/gateway/api/middleware/auth/paseto" + "github.com/gin-gonic/gin" +) + +func ApplyRoutes(r *gin.RouterGroup) { + plan := r.Group("/plan") + { + plan.POST("/webhook", StripeWebhookHandler) + plan.Use(paseto.PASETO(false)) + plan.POST("/", CreateStripeSession) + plan.DELETE("/", CancelStripeSubscription) + } +} diff --git a/api/v1/plan/webhook.go b/api/v1/plan/webhook.go new file mode 100644 index 00000000..e50fe4a0 --- /dev/null +++ b/api/v1/plan/webhook.go @@ -0,0 +1,87 @@ +package plan + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "os" + + "github.com/NetSepio/gateway/config/dbconfig" + "github.com/NetSepio/gateway/config/envconfig" + "github.com/NetSepio/gateway/models" + "github.com/NetSepio/gateway/util/pkg/logwrapper" + "github.com/gin-gonic/gin" + "github.com/stripe/stripe-go/v76" + "github.com/stripe/stripe-go/v76/webhook" +) + +func updateSubscriptionStatus(customerID, subscriptionStatus string, stripeSubscriptionId *string, stripeSubscriptionStatus stripe.SubscriptionStatus) error { + db := dbconfig.GetDb() + var user models.User + if err := db.Where("stripe_customer_id = ?", customerID).First(&user).Error; err != nil { + return err + } + + user.StripeSubscriptionId = stripeSubscriptionId + user.SubscriptionStatus = subscriptionStatus + user.StripeSubscriptionStatus = stripeSubscriptionStatus + return db.Save(&user).Error +} + +func StripeWebhookHandler(c *gin.Context) { + const MaxBodyBytes = int64(65536) + c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, MaxBodyBytes) + payload, err := io.ReadAll(c.Request.Body) + if err != nil { + logwrapper.Errorf("Error reading request body: %v", err) + c.Status(http.StatusServiceUnavailable) + return + } + + event, err := webhook.ConstructEvent(payload, c.GetHeader("Stripe-Signature"), envconfig.EnvVars.STRIPE_WEBHOOK_SECRET) + if err != nil { + logwrapper.Errorf("Error verifying webhook signature: %v", err) + c.Status(http.StatusBadRequest) + return + } + switch event.Type { + case stripe.EventTypeCustomerSubscriptionDeleted: + var subscription stripe.Subscription + err := json.Unmarshal(event.Data.Raw, &subscription) + if err != nil { + fmt.Fprintf(os.Stderr, "Error parsing webhook JSON: %v\n", err) + c.Status(http.StatusInternalServerError) + return + } + if err := updateSubscriptionStatus(subscription.Customer.ID, "basic", nil, "unset"); err != nil { + logwrapper.Errorf("Error updating subscription status: %v", err) + c.Status(http.StatusInternalServerError) + return + } + + case stripe.EventTypeCustomerSubscriptionUpdated: + var subscription stripe.Subscription + err := json.Unmarshal(event.Data.Raw, &subscription) + if err != nil { + fmt.Fprintf(os.Stderr, "Error parsing webhook JSON: %v\n", err) + // w.WriteHeader(http.StatusBadRequest) + return + } + if subscription.Status == "active" { + if err := updateSubscriptionStatus(subscription.Customer.ID, subscription.Items.Data[0].Price.LookupKey, &subscription.ID, subscription.Status); err != nil { + logwrapper.Errorf("Error updating subscription status: %v", err) + c.Status(http.StatusInternalServerError) + return + } + } else { + if err := updateSubscriptionStatus(subscription.Customer.ID, "basic", &subscription.ID, subscription.Status); err != nil { + logwrapper.Errorf("Error updating subscription status: %v", err) + c.Status(http.StatusInternalServerError) + return + } + } + } + + c.JSON(http.StatusOK, gin.H{"status": "received"}) +} diff --git a/api/v1/profile/profile.go b/api/v1/profile/profile.go index 02df63b6..7fa1ca1b 100644 --- a/api/v1/profile/profile.go +++ b/api/v1/profile/profile.go @@ -60,7 +60,7 @@ func getProfile(c *gin.Context) { db := dbconfig.GetDb() userId := c.GetString(paseto.CTX_USER_ID) var user models.User - err := db.Model(&models.User{}).Select("user_id, name, profile_picture_url,country, wallet_address, discord, twitter, email_id").Where("user_id = ?", userId).First(&user).Error + err := db.Model(&models.User{}).Select("user_id, name, profile_picture_url,country, wallet_address, discord, twitter, email_id, subscription_status").Where("user_id = ?", userId).First(&user).Error if err != nil { logrus.Error(err) httpo.NewErrorResponse(http.StatusInternalServerError, "Unexpected error occured").SendD(c) @@ -68,7 +68,7 @@ func getProfile(c *gin.Context) { } payload := GetProfilePayload{ - user.UserId, user.Name, user.WalletAddress, user.ProfilePictureUrl, user.Country, user.Discord, user.Twitter, user.EmailId, + user.UserId, user.Name, user.WalletAddress, user.ProfilePictureUrl, user.Country, user.Discord, user.Twitter, user.EmailId, user.SubscriptionStatus, } httpo.NewSuccessResponseP(200, "Profile fetched successfully", payload).SendD(c) } diff --git a/api/v1/profile/types.go b/api/v1/profile/types.go index 321d4a74..5e742ed0 100644 --- a/api/v1/profile/types.go +++ b/api/v1/profile/types.go @@ -9,12 +9,13 @@ type PatchProfileRequest struct { } type GetProfilePayload struct { - UserId string `json:"userId,omitempty"` - Name string `json:"name,omitempty"` - WalletAddress *string `json:"walletAddress,omitempty"` - ProfilePictureUrl string `json:"profilePictureUrl,omitempty"` - Country string `json:"country,omitempty"` - Discord string `json:"discord,omitempty"` - Twitter string `json:"twitter,omitempty"` - Email *string `json:"email,omitempty"` + UserId string `json:"userId,omitempty"` + Name string `json:"name,omitempty"` + WalletAddress string `json:"walletAddress"` + ProfilePictureUrl string `json:"profilePictureUrl,omitempty"` + Country string `json:"country,omitempty"` + Discord string `json:"discord,omitempty"` + Twitter string `json:"twitter,omitempty"` + Email *string `json:"email,omitempty"` + Plan string `json:"plan,omitempty"` } diff --git a/api/v1/v1.go b/api/v1/v1.go index d0394ba6..c762e3cb 100644 --- a/api/v1/v1.go +++ b/api/v1/v1.go @@ -11,6 +11,7 @@ import ( flowid "github.com/NetSepio/gateway/api/v1/flowid" "github.com/NetSepio/gateway/api/v1/getreviewerdetails" "github.com/NetSepio/gateway/api/v1/getreviews" + "github.com/NetSepio/gateway/api/v1/plan" "github.com/NetSepio/gateway/api/v1/profile" "github.com/NetSepio/gateway/api/v1/report" "github.com/NetSepio/gateway/api/v1/siteinsights" @@ -43,5 +44,6 @@ func ApplyRoutes(r *gin.RouterGroup) { report.ApplyRoutes(v1) account.ApplyRoutes(v1) siteinsights.ApplyRoutes(v1) + plan.ApplyRoutes(v1) } } diff --git a/app/app.go b/app/app.go index 3be7db82..df5f3460 100644 --- a/app/app.go +++ b/app/app.go @@ -5,6 +5,7 @@ import ( "github.com/NetSepio/gateway/api" "github.com/NetSepio/gateway/util/pkg/logwrapper" + "github.com/stripe/stripe-go/v76" "github.com/NetSepio/gateway/config/constants" "github.com/NetSepio/gateway/config/dbconfig" @@ -19,6 +20,7 @@ func Init() { envconfig.InitEnvVars() constants.InitConstants() logwrapper.Init() + stripe.Key = envconfig.EnvVars.STRIPE_SECRET_KEY GinApp = gin.Default() diff --git a/config/envconfig/envconfig.go b/config/envconfig/envconfig.go index df3f8b79..da62cdde 100644 --- a/config/envconfig/envconfig.go +++ b/config/envconfig/envconfig.go @@ -34,6 +34,10 @@ type config struct { EREBRUS_API_SG string `env:"EREBRUS_API_SG,notEmpty"` GOOGLE_AUDIENCE string `env:"GOOGLE_AUDIENCE,notEmpty"` OPENAI_API_KEY string `env:"OPENAI_API_KEY,notEmpty"` + STRIPE_SECRET_KEY string `env:"STRIPE_SECRET_KEY,notEmpty"` + STRIPE_SUCCESS_URL string `env:"STRIPE_SUCCESS_URL,notEmpty"` + STRIPE_CANCEL_URL string `env:"STRIPE_CANCEL_URL,notEmpty"` + STRIPE_WEBHOOK_SECRET string `env:"STRIPE_WEBHOOK_SECRET,notEmpty"` } var EnvVars config = config{} diff --git a/go.mod b/go.mod index 363a60f8..db386e8b 100644 --- a/go.mod +++ b/go.mod @@ -15,6 +15,7 @@ require ( github.com/joho/godotenv v1.4.0 github.com/sirupsen/logrus v1.8.1 github.com/stretchr/testify v1.8.4 + github.com/stripe/stripe-go/v76 v76.10.0 github.com/vk-rv/pvx v0.0.0-20210912195928-ac00bc32f6e7 golang.org/x/crypto v0.16.0 google.golang.org/api v0.154.0 diff --git a/go.sum b/go.sum index 1ed2ad7a..0d797bd7 100644 --- a/go.sum +++ b/go.sum @@ -545,6 +545,8 @@ github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stripe/stripe-go/v76 v76.10.0 h1:DRC1XnE1yfz972Oqvg9eZSdLctCRU8e+hHIXHOKVEt4= +github.com/stripe/stripe-go/v76 v76.10.0/go.mod h1:rw1MxjlAKKcZ+3FOXgTHgwiOa2ya6CPq6ykpJ0Q6Po4= github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7 h1:epCh84lMvA70Z7CTTCmYQn2CKbY8j86K7/FAIr141uY= github.com/syndtr/goleveldb v1.0.1-0.20210819022825-2ae1ddf74ef7/go.mod h1:q4W45IWZaF22tdD+VEXcAWRA037jwmWEB5VWYORlTpc= github.com/tinylib/msgp v1.0.2/go.mod h1:+d+yLhGm8mzTaHzB+wgMYrodPfmZrzkirds8fDWklFE= @@ -672,6 +674,7 @@ golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwY golang.org/x/net v0.0.0-20210119194325-5f4716e94777/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210220033124-5f55cee0dc0d/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20210805182204-aaa1db679c0d/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c= golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U= diff --git a/main.go b/main.go index 4e0eb8f2..e20a9288 100644 --- a/main.go +++ b/main.go @@ -25,8 +25,10 @@ func main() { if os.Getenv("DEBUG_MODE") == "true" { walletAddrLower := strings.ToLower("0xdd3933022e36e9a0a15d0522e20b7b580d38b54ec9cb28ae09697ce0f7c95b6b") newUser := &models.User{ - WalletAddress: &walletAddrLower, - UserId: "fc8fe270-ce16-4df9-a17f-979bcd824e32", + UserId: "fc8fe270-ce16-4df9-a17f-979bcd824e98", + SubscriptionStatus: "basic", + StripeSubscriptionStatus: "unset", + WalletAddress: &walletAddrLower, } if err := db.Create(newUser).Error; err != nil { logwrapper.Warn(err) diff --git a/migrations/000013_user_payment.up.sql b/migrations/000013_user_payment.up.sql new file mode 100644 index 00000000..f8c207da --- /dev/null +++ b/migrations/000013_user_payment.up.sql @@ -0,0 +1,34 @@ +ALTER TABLE + users +ADD + COLUMN subscription_status text CHECK ( + subscription_status IN ( + 'basic', + 'pro monthly', + 'pro yearly' + ) + ), +ADD + COLUMN stripe_customer_id text UNIQUE, +ADD + COLUMN stripe_subscription_status text CHECK ( + stripe_subscription_status IN ( + 'incomplete', + 'incomplete_expired', + 'trialing', + 'active', + 'past_due', + 'canceled', + 'unpaid', + 'unset' + ) + ) DEFAULT 'unset', +ADD + COLUMN stripe_subscription_id text UNIQUE; + +-- set basic subscription status and unset stripe_subscription_status for all users +UPDATE + users +SET + subscription_status = 'basic', + stripe_subscription_status = 'unset'; \ No newline at end of file diff --git a/models/User.go b/models/User.go index c442c0e0..dd13133b 100644 --- a/models/User.go +++ b/models/User.go @@ -1,14 +1,20 @@ package models +import "github.com/stripe/stripe-go/v76" + type User struct { - UserId string `gorm:"primary_key" json:"userId,omitempty"` - Name string `json:"name,omitempty"` - WalletAddress *string `json:"walletAddress,omitempty"` - Discord string `json:"discord"` - Twitter string `json:"twitter"` - FlowIds []FlowId `gorm:"foreignkey:UserId" json:"-"` - ProfilePictureUrl string `json:"profilePictureUrl,omitempty"` - Country string `json:"country,omitempty"` - Feedbacks []UserFeedback `gorm:"foreignkey:UserId" json:"userFeedbacks"` - EmailId *string `json:"emailId,omitempty"` + UserId string `gorm:"primary_key" json:"userId,omitempty"` + Name string `json:"name,omitempty"` + WalletAddress string `gorm:"unique" json:"walletAddress"` + Discord string `json:"discord"` + Twitter string `json:"twitter"` + FlowIds []FlowId `gorm:"foreignkey:UserId" json:"-"` + ProfilePictureUrl string `json:"profilePictureUrl,omitempty"` + Country string `json:"country,omitempty"` + Feedbacks []UserFeedback `gorm:"foreignkey:UserId" json:"userFeedbacks"` + EmailId *string `json:"emailId,omitempty"` + SubscriptionStatus string `json:"subscriptionStatus,omitempty"` + StripeCustomerId string `json:"-"` + StripeSubscriptionId *string `json:"stripeSubscriptionId,omitempty"` + StripeSubscriptionStatus stripe.SubscriptionStatus `json:"stripeSubscriptionStatus,omitempty"` }