diff --git a/requirements.txt b/requirements.txt index fd954f8..e1bfc43 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ # Core dependencies -discord.py>=2.3.0 +discord.py>=2.4.0 flask>=2.3.0 flask-cors>=4.0.0 python-dotenv>=1.0.0 @@ -24,6 +24,7 @@ Pillow>=10.0.0 # AI and external services openai>=1.0.0 anthropic>=0.7.0 +stripe>=7.0.0 # Data processing and utilities pandas>=2.0.0 diff --git a/src/cogs/ai/__init__.py b/src/cogs/ai/__init__.py new file mode 100644 index 0000000..b197bb3 --- /dev/null +++ b/src/cogs/ai/__init__.py @@ -0,0 +1,10 @@ +"""AI cogs module.""" + + +async def setup(bot): + """Setup function for AI cogs.""" + from .chat import PerplexityChat + from .server_designer import ServerDesigner + + await bot.add_cog(PerplexityChat(bot)) + await bot.add_cog(ServerDesigner(bot)) diff --git a/src/cogs/base.py b/src/cogs/base.py index deb7f92..d8237b5 100644 --- a/src/cogs/base.py +++ b/src/cogs/base.py @@ -67,6 +67,36 @@ async def cog_after_invoke(self, ctx: commands.Context): # Track command usage if needed pass + async def cog_app_command_error( + self, + interaction: discord.Interaction, + error: discord.app_commands.AppCommandError + ): + """Handle slash command errors for all cogs that inherit BaseCog.""" + self.logger.error( + f"App command error: {type(error).__name__}: {error}", + exc_info=error + ) + + if isinstance(error, discord.app_commands.CommandOnCooldown): + msg = f"⏰ This command is on cooldown. Try again in {error.retry_after:.1f} seconds." + elif isinstance(error, discord.app_commands.MissingPermissions): + msg = "❌ You don't have permission to use this command." + elif isinstance(error, discord.app_commands.BotMissingPermissions): + msg = "❌ I don't have the required permissions to execute this command." + elif isinstance(error, discord.app_commands.CheckFailure): + msg = str(error) or "❌ You don't meet the requirements for this command." + else: + msg = "❌ An error occurred while executing this command." + + try: + if interaction.response.is_done(): + await interaction.followup.send(msg, ephemeral=True) + else: + await interaction.response.send_message(msg, ephemeral=True) + except Exception: + pass + async def cog_command_error(self, ctx: commands.Context, error: Exception): """Handle cog-specific errors.""" # Log the error diff --git a/src/cogs/community/__init__.py b/src/cogs/community/__init__.py index 90250ad..2c7b04f 100644 --- a/src/cogs/community/__init__.py +++ b/src/cogs/community/__init__.py @@ -6,10 +6,12 @@ async def setup(bot): from .leveling import Levelling from .registration import Register from .welcome import Welcomer + from .starboard import Starboard from .game_scraper import GameScraper - + await bot.add_cog(AutoRole(bot)) await bot.add_cog(Levelling(bot)) await bot.add_cog(Register(bot)) await bot.add_cog(Welcomer(bot)) + await bot.add_cog(Starboard(bot)) await bot.add_cog(GameScraper(bot)) diff --git a/src/cogs/custom_commands/__init__.py b/src/cogs/custom_commands/__init__.py new file mode 100644 index 0000000..25d66d9 --- /dev/null +++ b/src/cogs/custom_commands/__init__.py @@ -0,0 +1,8 @@ +"""Custom commands cogs module.""" + + +async def setup(bot): + """Setup function for custom commands cog.""" + from .manager import CustomCommandsManager + + await bot.add_cog(CustomCommandsManager(bot)) diff --git a/src/cogs/premium/premium.py b/src/cogs/premium/premium.py index 06cd7ec..6987aef 100644 --- a/src/cogs/premium/premium.py +++ b/src/cogs/premium/premium.py @@ -1,43 +1,59 @@ """Premium subscription management cog.""" import discord -import logging from discord.ext import commands from discord import app_commands from typing import Optional -from src.database.models.premium import PremiumTier, PaymentMethod +from src.cogs.base import BaseCog +from src.database.models.premium import PremiumTier, PaymentMethod, PremiumModel from src.services.crypto_verification import CryptoVerificationService -from src.utils.database.connection import initialize_mongodb -from src.database.models.premium import PremiumModel +from src.services.stripe_service import StripeService -logger = logging.getLogger('premium') - -class Premium(commands.Cog): +class Premium(BaseCog): """Premium subscription management commands""" - + def __init__(self, bot): - self.bot = bot - self.mongo_db = initialize_mongodb() - self.premium_collection = self.mongo_db['premium'] - self.premium_model = PremiumModel(self.premium_collection) + super().__init__(bot) self.crypto_service = CryptoVerificationService() - - # Store premium model in bot for easy access by decorators + self.stripe_service = StripeService() + self.premium_model: Optional[PremiumModel] = None + + async def cog_load(self): + """Initialize premium model after DB is ready.""" + await super().cog_load() + collection = self.async_db.get_collection("premium") + self.premium_model = PremiumModel(collection) + # Make accessible to premium_required decorator self.bot.premium_model = self.premium_model - - logger.info("Premium cog initialized") - - @app_commands.command(name="premium_info", description="Get information about premium subscriptions") - async def premium_info(self, interaction: discord.Interaction): + self.logger.info("Premium cog initialized with async DB") + + # ------------------------------------------------------------------------- + # Command group + # ------------------------------------------------------------------------- + + @commands.hybrid_group(name="premium", description="Premium subscription commands", invoke_without_command=True) + async def premium_group(self, ctx: commands.Context): + """Show premium info when invoked without subcommand.""" + await self._send_info_embed(ctx) + + # ------------------------------------------------------------------------- + # /premium info + # ------------------------------------------------------------------------- + + @premium_group.command(name="info", description="Get information about premium subscriptions") + async def premium_info(self, ctx: commands.Context): """Show premium subscription information""" + await self._send_info_embed(ctx) + + async def _send_info_embed(self, ctx: commands.Context): embed = discord.Embed( title="💎 Premium Subscriptions", - description="Unlock advanced features with a premium subscription!", + description="Unlock advanced features for your server with a premium subscription!", color=discord.Color.gold() ) - + embed.add_field( name="💰 Basic - $5/month", value=( @@ -48,7 +64,7 @@ async def premium_info(self, interaction: discord.Interaction): ), inline=True ) - + embed.add_field( name="🔥 Pro - $15/month", value=( @@ -60,7 +76,7 @@ async def premium_info(self, interaction: discord.Interaction): ), inline=True ) - + embed.add_field( name="⭐ Enterprise - $50/month", value=( @@ -71,76 +87,85 @@ async def premium_info(self, interaction: discord.Interaction): ), inline=True ) - + embed.add_field( name="💳 Payment Methods", - value="We accept BTC, ETH, USDT, and USDC", + value="💳 Credit/Debit Card (Stripe) · BTC · ETH · USDT · USDC", inline=False ) - + embed.add_field( name="📝 How to Subscribe", value=( - "1. Use `/premium_payment ` to get payment details\n" - "2. Send the crypto payment to the provided address\n" - "3. Use `/premium_activate ` to activate your subscription" + "**Card:** `!premium payment stripe` → follow the checkout link\n" + "**Crypto:** `!premium payment ` → send payment → `!premium activate `" ), inline=False ) - - embed.set_footer(text="Premium benefits last for 30 days from activation") - - await interaction.response.send_message(embed=embed) - - @app_commands.command(name="premium_status", description="Check your premium status") - async def premium_status(self, interaction: discord.Interaction): - """Check premium subscription status""" - premium_data = await self.premium_model.get_user_premium( - interaction.user.id, - interaction.guild.id - ) - + + embed.set_footer(text="Premium is server-wide and lasts 30 days from activation") + await ctx.send(embed=embed) + + # ------------------------------------------------------------------------- + # /premium status + # ------------------------------------------------------------------------- + + @premium_group.command(name="status", description="Check this server's premium status") + async def premium_status(self, ctx: commands.Context): + """Check premium subscription status for this guild""" + if not ctx.guild: + await ctx.send("❌ This command must be used in a server.", ephemeral=True) + return + + premium_data = await self.premium_model.get_guild_premium(ctx.guild.id) + if not premium_data: embed = discord.Embed( title="Premium Status", - description="You don't have an active premium subscription.", + description="This server doesn't have an active premium subscription.", color=discord.Color.red() ) embed.add_field( name="Get Premium", - value="Use `/premium_info` to learn about premium benefits!", + value="Use `/premium info` to learn about premium benefits!", inline=False ) else: tier = premium_data.get("premium_tier", "basic") expires_at = premium_data.get("expires_at", "Unknown") - + payment_method = premium_data.get("payment_method", "Unknown") + activated_by = premium_data.get("activated_by") + embed = discord.Embed( title="💎 Premium Status", - description=f"You have an active **{tier.title()}** subscription!", + description=f"This server has an active **{tier.title()}** subscription!", color=discord.Color.gold() ) embed.add_field(name="Tier", value=tier.title(), inline=True) - embed.add_field(name="Expires", value=expires_at.split('T')[0], inline=True) - embed.add_field( - name="Payment Method", - value=premium_data.get("payment_method", "Unknown"), - inline=True - ) - - await interaction.response.send_message(embed=embed, ephemeral=True) - - @app_commands.command(name="premium_payment", description="Get payment details for a premium tier") + expires_display = expires_at.split('T')[0] if isinstance(expires_at, str) else str(expires_at) + embed.add_field(name="Expires", value=expires_display, inline=True) + embed.add_field(name="Payment Method", value=payment_method, inline=True) + if activated_by: + embed.add_field(name="Activated By", value=f"<@{activated_by}>", inline=True) + + await ctx.send(embed=embed, ephemeral=True) + + # ------------------------------------------------------------------------- + # /premium payment + # ------------------------------------------------------------------------- + + @premium_group.command(name="payment", description="Get payment details for a premium tier") @app_commands.describe( tier="Premium tier to subscribe to", - crypto="Cryptocurrency to use for payment" + method="Payment method (stripe, BTC, ETH, USDT, USDC)" ) @app_commands.choices(tier=[ app_commands.Choice(name="Basic ($5/month)", value="basic"), app_commands.Choice(name="Pro ($15/month)", value="pro"), app_commands.Choice(name="Enterprise ($50/month)", value="enterprise"), ]) - @app_commands.choices(crypto=[ + @app_commands.choices(method=[ + app_commands.Choice(name="Credit/Debit Card (Stripe)", value="stripe"), app_commands.Choice(name="Bitcoin (BTC)", value="BTC"), app_commands.Choice(name="Ethereum (ETH)", value="ETH"), app_commands.Choice(name="USDT (Tether)", value="USDT"), @@ -148,67 +173,100 @@ async def premium_status(self, interaction: discord.Interaction): ]) async def premium_payment( self, - interaction: discord.Interaction, + ctx: commands.Context, tier: str, - crypto: str + method: str ): """Get payment information for a premium tier""" - tier_prices = { - "basic": 5.0, - "pro": 15.0, - "enterprise": 50.0 - } - - amount_usd = tier_prices.get(tier, 5.0) + if not ctx.guild: + await ctx.send("❌ This command must be used in a server.", ephemeral=True) + return + + if method.lower() == "stripe": + await self._handle_stripe_payment(ctx, tier) + else: + await self._handle_crypto_payment(ctx, tier, method) + + async def _handle_stripe_payment(self, ctx: commands.Context, tier: str): + """Generate a Stripe checkout link.""" + base_url = "https://contro.space" + result = self.stripe_service.create_checkout_session( + guild_id=ctx.guild.id, + user_id=ctx.author.id, + tier=tier, + success_url=f"{base_url}/premium/success?session_id={{CHECKOUT_SESSION_ID}}", + cancel_url=f"{base_url}/premium/cancel" + ) + + if "error" in result: + await ctx.send(f"❌ {result['error']}", ephemeral=True) + return + + tier_prices = {"basic": 5, "pro": 15, "enterprise": 50} + amount = tier_prices.get(tier.lower(), 0) + + embed = discord.Embed( + title=f"💳 Stripe Checkout - {tier.title()} Premium", + description=f"Pay **${amount}/month** with your credit or debit card.", + color=discord.Color.blue() + ) + embed.add_field( + name="Checkout Link", + value=f"[Click here to pay]({result['checkout_url']})", + inline=False + ) + embed.add_field( + name="ℹ️ Note", + value="Your premium will be activated automatically after successful payment.", + inline=False + ) + embed.set_footer(text="Link expires after 24 hours") + await ctx.send(embed=embed, ephemeral=True) + + async def _handle_crypto_payment(self, ctx: commands.Context, tier: str, crypto: str): + """Show crypto payment details.""" + tier_prices = {"basic": 5.0, "pro": 15.0, "enterprise": 50.0} + amount_usd = tier_prices.get(tier.lower(), 5.0) payment_info = self.crypto_service.get_payment_info(crypto, amount_usd) - + if "error" in payment_info: - await interaction.response.send_message( - f"❌ {payment_info['error']}", - ephemeral=True - ) + await ctx.send(f"❌ {payment_info['error']}", ephemeral=True) return - + embed = discord.Embed( title=f"💳 Payment Details - {tier.title()} Premium", description=f"Send **${amount_usd:.2f} USD** worth of {crypto}", color=discord.Color.blue() ) - embed.add_field( name="Payment Address", value=f"```{payment_info['address']}```", inline=False ) - - embed.add_field( - name="Network", - value=payment_info['network'], - inline=True - ) - + embed.add_field(name="Network", value=payment_info['network'], inline=True) embed.add_field( name="Required Confirmations", value=str(payment_info['min_confirmations']), inline=True ) - embed.add_field( name="⚠️ Important", value=( "• Make sure to send the exact amount\n" "• Wait for confirmations to complete\n" "• Save your transaction hash\n" - f"• Use `/premium_activate ` after payment" + f"• Use `!premium activate ` after payment" ), inline=False ) - embed.set_footer(text="Payment must be completed within 24 hours") - - await interaction.response.send_message(embed=embed, ephemeral=True) - - @app_commands.command(name="premium_activate", description="Activate premium with a transaction hash") + await ctx.send(embed=embed, ephemeral=True) + + # ------------------------------------------------------------------------- + # /premium activate + # ------------------------------------------------------------------------- + + @premium_group.command(name="activate", description="Activate premium with a transaction hash") @app_commands.describe( tx_hash="Blockchain transaction hash", crypto="Cryptocurrency used for payment", @@ -227,65 +285,59 @@ async def premium_payment( ]) async def premium_activate( self, - interaction: discord.Interaction, + ctx: commands.Context, tx_hash: str, crypto: str, tier: str ): """Activate premium subscription with transaction verification""" - await interaction.response.defer(ephemeral=True) - - tier_prices = { - "basic": 5.0, - "pro": 15.0, - "enterprise": 50.0 - } - - expected_amount = tier_prices.get(tier, 5.0) - - # Verify transaction + if not ctx.guild: + await ctx.send("❌ This command must be used in a server.", ephemeral=True) + return + + await ctx.defer(ephemeral=True) + + tier_prices = {"basic": 5.0, "pro": 15.0, "enterprise": 50.0} + expected_amount = tier_prices.get(tier.lower(), 5.0) + verification_result = await self.crypto_service.verify_transaction( - crypto, - tx_hash, - expected_amount + crypto, tx_hash, expected_amount ) - + if not verification_result.get("verified", False): - await interaction.followup.send( + await ctx.send( f"❌ Payment verification failed: {verification_result.get('error', 'Unknown error')}", ephemeral=True ) return - - # Activate premium + success = await self.premium_model.activate_premium( - interaction.user.id, - interaction.guild.id, + ctx.author.id, + ctx.guild.id, tier, crypto, tx_hash ) - + if success: embed = discord.Embed( title="✅ Premium Activated!", - description=f"Your **{tier.title()} Premium** subscription is now active!", + description=f"**{tier.title()} Premium** is now active for **{ctx.guild.name}**!", color=discord.Color.green() ) embed.add_field(name="Duration", value="30 days", inline=True) embed.add_field(name="Transaction", value=f"`{tx_hash[:16]}...`", inline=True) - - await interaction.followup.send(embed=embed, ephemeral=True) + await ctx.send(embed=embed, ephemeral=True) else: - await interaction.followup.send( - "❌ Failed to activate premium. Please contact support.", - ephemeral=True - ) - - # Admin commands - @app_commands.command(name="premium_grant", description="[ADMIN] Grant premium to a user") + await ctx.send("❌ Failed to activate premium. Please contact support.", ephemeral=True) + + # ------------------------------------------------------------------------- + # /premium grant (admin) + # ------------------------------------------------------------------------- + + @premium_group.command(name="grant", description="[ADMIN] Grant premium to this server") @app_commands.describe( - user="User to grant premium to", + user="User who is granted (for record-keeping)", tier="Premium tier to grant", days="Duration in days" ) @@ -297,56 +349,85 @@ async def premium_activate( @commands.has_permissions(administrator=True) async def premium_grant( self, - interaction: discord.Interaction, + ctx: commands.Context, user: discord.Member, tier: str, days: int = 30 ): - """Grant premium to a user (admin only)""" + """Grant premium to a server (admin only)""" + if not ctx.guild: + await ctx.send("❌ This command must be used in a server.", ephemeral=True) + return + success = await self.premium_model.activate_premium( user.id, - interaction.guild.id, + ctx.guild.id, tier, "ADMIN_GRANT", - f"admin_grant_{interaction.user.id}", + f"admin_grant_{ctx.author.id}", days ) - + if success: - await interaction.response.send_message( - f"✅ Granted **{tier.title()} Premium** to {user.mention} for {days} days!", + await ctx.send( + f"✅ Granted **{tier.title()} Premium** to **{ctx.guild.name}** for {days} days!", ephemeral=True ) else: - await interaction.response.send_message( - "❌ Failed to grant premium.", + await ctx.send("❌ Failed to grant premium.", ephemeral=True) + + # ------------------------------------------------------------------------- + # /premium revoke (admin) + # ------------------------------------------------------------------------- + + @premium_group.command(name="revoke", description="[ADMIN] Revoke premium from this server") + @commands.has_permissions(administrator=True) + async def premium_revoke(self, ctx: commands.Context): + """Revoke premium from this server (admin only)""" + if not ctx.guild: + await ctx.send("❌ This command must be used in a server.", ephemeral=True) + return + + success = await self.premium_model.revoke_premium(ctx.author.id, ctx.guild.id) + + if success: + await ctx.send(f"✅ Revoked premium from **{ctx.guild.name}**.", ephemeral=True) + else: + await ctx.send( + "❌ This server doesn't have active premium or revocation failed.", ephemeral=True ) - - @app_commands.command(name="premium_revoke", description="[ADMIN] Revoke premium from a user") - @app_commands.describe(user="User to revoke premium from") - @commands.has_permissions(administrator=True) - async def premium_revoke( + + # ------------------------------------------------------------------------- + # app_commands error handler for slash command permission errors + # ------------------------------------------------------------------------- + + async def cog_app_command_error( self, interaction: discord.Interaction, - user: discord.Member + error: app_commands.AppCommandError ): - """Revoke premium from a user (admin only)""" - success = await self.premium_model.revoke_premium( - user.id, - interaction.guild.id - ) - - if success: - await interaction.response.send_message( - f"✅ Revoked premium from {user.mention}.", - ephemeral=True - ) + if isinstance(error, app_commands.MissingPermissions): + if interaction.response.is_done(): + await interaction.followup.send( + "❌ You don't have permission to use this command.", ephemeral=True + ) + else: + await interaction.response.send_message( + "❌ You don't have permission to use this command.", ephemeral=True + ) + elif isinstance(error, app_commands.CheckFailure): + msg = str(error) or "❌ You don't meet the requirements for this command." + if interaction.response.is_done(): + await interaction.followup.send(msg, ephemeral=True) + else: + await interaction.response.send_message(msg, ephemeral=True) else: - await interaction.response.send_message( - "❌ User doesn't have active premium or revocation failed.", - ephemeral=True - ) + self.logger.error(f"Unhandled app command error in Premium cog: {error}", exc_info=error) + if not interaction.response.is_done(): + await interaction.response.send_message( + "❌ An error occurred while executing this command.", ephemeral=True + ) async def setup(bot): diff --git a/src/cogs/support/__init__.py b/src/cogs/support/__init__.py new file mode 100644 index 0000000..9303667 --- /dev/null +++ b/src/cogs/support/__init__.py @@ -0,0 +1,8 @@ +"""Support cogs module.""" + + +async def setup(bot): + """Setup function for support cogs.""" + from .tickets import Ticket + + await bot.add_cog(Ticket(bot)) diff --git a/src/cogs/utility/__init__.py b/src/cogs/utility/__init__.py index c975462..9a26048 100644 --- a/src/cogs/utility/__init__.py +++ b/src/cogs/utility/__init__.py @@ -1,19 +1,22 @@ """Utility cogs module.""" -# This allows loading the cogs as a package -from . import info, general, ai_chat, tickets, temp_channels, starboard, invites, interface, bump - async def setup(bot): """Setup function for loading all utility cogs.""" from .info import InfoUtility from .general import Utility - from .custom_commands_manager import CustomCommandsManager - + from .ping import PingCog + from .bump import Bump + from .invites import InviteTracker + from .interface import Interface + from .temp_channels import TempChannels + from .custom_status_manager import CustomStatusManager + await bot.add_cog(InfoUtility(bot)) await bot.add_cog(Utility(bot)) - await bot.add_cog(CustomCommandsManager(bot)) - - # Import and add PingCog - from .ping import PingCog await bot.add_cog(PingCog(bot)) + await bot.add_cog(Bump(bot)) + await bot.add_cog(InviteTracker(bot)) + await bot.add_cog(Interface(bot)) + await bot.add_cog(TempChannels(bot)) + await bot.add_cog(CustomStatusManager(bot)) diff --git a/src/database/models/premium.py b/src/database/models/premium.py index b4d811e..7ccb7bb 100644 --- a/src/database/models/premium.py +++ b/src/database/models/premium.py @@ -1,7 +1,7 @@ """Premium user database models and utilities.""" import logging -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from typing import Optional, Dict, Any from enum import Enum @@ -16,22 +16,23 @@ class PremiumTier(str, Enum): class PaymentMethod(str, Enum): - """Supported cryptocurrency payment methods""" + """Supported payment methods""" BTC = "BTC" ETH = "ETH" USDT = "USDT" USDC = "USDC" + STRIPE = "stripe" class PremiumModel: """Model for premium user data""" - + TIER_PRICES = { PremiumTier.BASIC: 5.0, # USD PremiumTier.PRO: 15.0, PremiumTier.ENTERPRISE: 50.0 } - + TIER_FEATURES = { PremiumTier.BASIC: { "enhanced_features": True, @@ -65,143 +66,163 @@ class PremiumModel: "custom_integrations": True } } - + def __init__(self, db_collection): """Initialize premium model with database collection - + Args: - db_collection: MongoDB collection for premium data + db_collection: Motor async MongoDB collection for premium data """ self.collection = db_collection - - async def get_user_premium(self, user_id: int, guild_id: int) -> Optional[Dict[str, Any]]: - """Get premium status for a user in a guild - + + async def get_guild_premium(self, guild_id: int) -> Optional[Dict[str, Any]]: + """Get premium status for a guild + Args: - user_id: Discord user ID guild_id: Discord guild ID - + Returns: Premium data dict or None if not premium """ try: - premium_data = self.collection.find_one({ - "user_id": str(user_id), + premium_data = await self.collection.find_one({ "guild_id": str(guild_id) }) - + if not premium_data: return None - - # Check if expired + if self.is_expired(premium_data): - logger.info(f"Premium expired for user {user_id} in guild {guild_id}") - # Mark as expired but don't delete (for renewal) - self.collection.update_one( - {"user_id": str(user_id), "guild_id": str(guild_id)}, + logger.info(f"Premium expired for guild {guild_id}") + await self.collection.update_one( + {"guild_id": str(guild_id)}, {"$set": {"status": "expired"}} ) return None - + return premium_data - + except Exception as e: - logger.error(f"Error getting premium status: {e}") + logger.error(f"Error getting guild premium status: {e}") return None - + + async def get_user_premium(self, user_id: int, guild_id: int) -> Optional[Dict[str, Any]]: + """Get premium status for a user in a guild (guild-based lookup) + + Args: + user_id: Discord user ID + guild_id: Discord guild ID + + Returns: + Premium data dict or None if not premium + """ + return await self.get_guild_premium(guild_id) + def is_expired(self, premium_data: Dict[str, Any]) -> bool: """Check if premium subscription is expired - + Args: premium_data: Premium data dictionary - + Returns: True if expired, False otherwise """ if not premium_data.get("expires_at"): return False - + try: expires_at = premium_data["expires_at"] if isinstance(expires_at, str): expires_at = datetime.fromisoformat(expires_at.replace('Z', '+00:00')) - - return datetime.now() > expires_at + + # Ensure timezone-aware comparison + if expires_at.tzinfo is None: + expires_at = expires_at.replace(tzinfo=timezone.utc) + + return datetime.now(timezone.utc) > expires_at except Exception as e: logger.error(f"Error checking expiry: {e}") return True - + async def has_feature(self, user_id: int, guild_id: int, feature: str) -> bool: - """Check if user has a specific premium feature - + """Check if a guild has a specific premium feature + Args: - user_id: Discord user ID + user_id: Discord user ID (unused, kept for API compatibility) guild_id: Discord guild ID feature: Feature name to check - + Returns: - True if user has the feature, False otherwise + True if guild has the feature, False otherwise """ - premium_data = await self.get_user_premium(user_id, guild_id) + premium_data = await self.get_guild_premium(guild_id) if not premium_data: return False - - tier = premium_data.get("premium_tier", "basic") + + tier_value = premium_data.get("premium_tier", "basic") + # Convert string to PremiumTier enum for lookup + try: + tier = PremiumTier(tier_value) + except ValueError: + tier = PremiumTier.BASIC features = self.TIER_FEATURES.get(tier, {}) - + return features.get(feature, False) - + async def activate_premium( self, user_id: int, guild_id: int, tier: PremiumTier, - payment_method: PaymentMethod, + payment_method: str, transaction_hash: str, duration_days: int = 30 ) -> bool: - """Activate premium for a user - + """Activate premium for a guild + Args: - user_id: Discord user ID + user_id: Discord user ID (admin/purchaser) guild_id: Discord guild ID - tier: Premium tier + tier: Premium tier (PremiumTier enum or string) payment_method: Payment method used - transaction_hash: Blockchain transaction hash + transaction_hash: Transaction hash or reference duration_days: Duration in days (default 30) - + Returns: True if successful, False otherwise """ try: - now = datetime.now() + now = datetime.now(timezone.utc) expires_at = now + timedelta(days=duration_days) - + + # Accept both enum and string for tier + tier_value = tier.value if isinstance(tier, PremiumTier) else str(tier) + payment_value = payment_method.value if isinstance(payment_method, PaymentMethod) else str(payment_method) + premium_data = { - "user_id": str(user_id), "guild_id": str(guild_id), - "premium_tier": tier, - "payment_method": payment_method, + "activated_by": str(user_id), + "premium_tier": tier_value, + "payment_method": payment_value, "transaction_hash": transaction_hash, "activated_at": now.isoformat(), "expires_at": expires_at.isoformat(), "status": "active" } - - # Upsert (update if exists, insert if not) - self.collection.update_one( - {"user_id": str(user_id), "guild_id": str(guild_id)}, + + await self.collection.update_one( + {"guild_id": str(guild_id)}, {"$set": premium_data}, upsert=True ) - - logger.info(f"Activated {tier} premium for user {user_id} in guild {guild_id}") + + logger.info(f"Activated {tier_value} premium for guild {guild_id} by user {user_id}") return True - + except Exception as e: logger.error(f"Error activating premium: {e}") return False - + async def extend_premium( self, user_id: int, @@ -209,60 +230,61 @@ async def extend_premium( additional_days: int = 30 ) -> bool: """Extend existing premium subscription - + Args: user_id: Discord user ID guild_id: Discord guild ID additional_days: Days to add (default 30) - + Returns: True if successful, False otherwise """ try: - premium_data = await self.get_user_premium(user_id, guild_id) + premium_data = await self.get_guild_premium(guild_id) if not premium_data: - logger.warning(f"No active premium found for user {user_id}") + logger.warning(f"No active premium found for guild {guild_id}") return False - - current_expiry = datetime.fromisoformat( - premium_data["expires_at"].replace('Z', '+00:00') - ) + + expires_str = premium_data["expires_at"] + current_expiry = datetime.fromisoformat(expires_str.replace('Z', '+00:00')) + if current_expiry.tzinfo is None: + current_expiry = current_expiry.replace(tzinfo=timezone.utc) new_expiry = current_expiry + timedelta(days=additional_days) - - self.collection.update_one( - {"user_id": str(user_id), "guild_id": str(guild_id)}, + + await self.collection.update_one( + {"guild_id": str(guild_id)}, {"$set": {"expires_at": new_expiry.isoformat(), "status": "active"}} ) - - logger.info(f"Extended premium for user {user_id} by {additional_days} days") + + logger.info(f"Extended premium for guild {guild_id} by {additional_days} days") return True - + except Exception as e: logger.error(f"Error extending premium: {e}") return False - + async def revoke_premium(self, user_id: int, guild_id: int) -> bool: - """Revoke premium access - + """Revoke premium access for a guild + Args: - user_id: Discord user ID + user_id: Discord user ID (unused, kept for API compatibility) guild_id: Discord guild ID - + Returns: True if successful, False otherwise """ try: - result = self.collection.update_one( - {"user_id": str(user_id), "guild_id": str(guild_id)}, - {"$set": {"status": "revoked", "expires_at": datetime.now().isoformat()}} + result = await self.collection.update_one( + {"guild_id": str(guild_id)}, + {"$set": {"status": "revoked", "expires_at": datetime.now(timezone.utc).isoformat()}} ) - + if result.modified_count > 0: - logger.info(f"Revoked premium for user {user_id} in guild {guild_id}") + logger.info(f"Revoked premium for guild {guild_id}") return True - + return False - + except Exception as e: logger.error(f"Error revoking premium: {e}") return False diff --git a/src/services/__init__.py b/src/services/__init__.py index 418fec7..21f9930 100644 --- a/src/services/__init__.py +++ b/src/services/__init__.py @@ -10,11 +10,14 @@ from .user_service import UserService from .giveaway_service import GiveawayService from .moderation_service import ModerationService +from .payment import CryptoVerificationService, StripeService __all__ = [ 'BaseService', 'GuildService', 'UserService', 'GiveawayService', - 'ModerationService' + 'ModerationService', + 'CryptoVerificationService', + 'StripeService', ] diff --git a/src/services/crypto_verification.py b/src/services/crypto_verification.py index c9bf3de..5403dcc 100644 --- a/src/services/crypto_verification.py +++ b/src/services/crypto_verification.py @@ -1,305 +1,7 @@ -"""Crypto payment verification service for premium subscriptions.""" +"""Crypto verification service — backward-compatibility re-export. -import logging -import aiohttp -from typing import Optional, Dict, Any -from enum import Enum +The canonical implementation lives in src/services/payment/crypto.py. +""" +from .payment.crypto import CryptoVerificationService, CryptoNetwork -logger = logging.getLogger('crypto_verification') - - -class CryptoNetwork(str, Enum): - """Supported blockchain networks""" - BITCOIN = "bitcoin" - ETHEREUM = "ethereum" - TRON = "tron" # For USDT - - -class CryptoVerificationService: - """Service to verify cryptocurrency transactions""" - - # Blockchain explorer APIs - EXPLORERS = { - CryptoNetwork.BITCOIN: "https://blockchair.com/bitcoin/transaction/{}", - CryptoNetwork.ETHEREUM: "https://api.etherscan.io/api?module=transaction&action=gettxreceiptstatus&txhash={}&apikey=YourApiKeyToken", - CryptoNetwork.TRON: "https://apilist.tronscan.org/api/transaction-info?hash={}" - } - - # Payment addresses for each cryptocurrency - PAYMENT_ADDRESSES = { - "BTC": "bc1q...", # TODO: Add actual Bitcoin address - "ETH": "0x...", # TODO: Add actual Ethereum address - "USDT": "TR...", # TODO: Add actual USDT (TRC20) address - "USDC": "0x..." # TODO: Add actual USDC (ERC20) address - } - - # Minimum confirmations required - MIN_CONFIRMATIONS = { - "BTC": 3, - "ETH": 12, - "USDT": 19, - "USDC": 12 - } - - def __init__(self, api_keys: Optional[Dict[str, str]] = None): - """Initialize crypto verification service - - Args: - api_keys: Optional dict of API keys for blockchain explorers - """ - self.api_keys = api_keys or {} - - async def verify_transaction( - self, - crypto: str, - tx_hash: str, - expected_amount: float, - recipient_address: Optional[str] = None - ) -> Dict[str, Any]: - """Verify a cryptocurrency transaction - - Args: - crypto: Cryptocurrency code (BTC, ETH, USDT, USDC) - tx_hash: Transaction hash - expected_amount: Expected payment amount in USD - recipient_address: Optional recipient address to verify - - Returns: - Dict with verification result - """ - try: - if crypto == "BTC": - return await self._verify_bitcoin(tx_hash, expected_amount, recipient_address) - elif crypto == "ETH": - return await self._verify_ethereum(tx_hash, expected_amount, recipient_address) - elif crypto in ["USDT", "USDC"]: - return await self._verify_stablecoin(crypto, tx_hash, expected_amount, recipient_address) - else: - return { - "verified": False, - "error": f"Unsupported cryptocurrency: {crypto}" - } - except Exception as e: - logger.error(f"Error verifying transaction: {e}") - return { - "verified": False, - "error": str(e) - } - - async def _verify_bitcoin( - self, - tx_hash: str, - expected_amount_usd: float, - recipient_address: Optional[str] = None - ) -> Dict[str, Any]: - """Verify Bitcoin transaction - - Args: - tx_hash: Bitcoin transaction hash - expected_amount_usd: Expected amount in USD - recipient_address: Optional recipient address - - Returns: - Verification result dict - """ - try: - # Use Blockchair API - url = f"https://api.blockchair.com/bitcoin/dashboards/transaction/{tx_hash}" - - async with aiohttp.ClientSession() as session: - async with session.get(url) as response: - if response.status != 200: - return { - "verified": False, - "error": f"API error: {response.status}" - } - - data = await response.json() - - if "data" not in data or tx_hash not in data["data"]: - return { - "verified": False, - "error": "Transaction not found" - } - - tx_data = data["data"][tx_hash]["transaction"] - - # Get BTC price to convert amount - btc_price = await self._get_crypto_price("BTC") - expected_amount_btc = expected_amount_usd / btc_price - - # Verify confirmations - confirmations = tx_data.get("confirmations", 0) - if confirmations < self.MIN_CONFIRMATIONS["BTC"]: - return { - "verified": False, - "error": f"Insufficient confirmations: {confirmations}/{self.MIN_CONFIRMATIONS['BTC']}" - } - - # Verify amount (with 5% tolerance for fees and price fluctuation) - output_total_btc = tx_data.get("output_total", 0) / 100000000 # Satoshi to BTC - - if output_total_btc < expected_amount_btc * 0.95: - return { - "verified": False, - "error": f"Amount mismatch: {output_total_btc} BTC (expected ~{expected_amount_btc} BTC)" - } - - # TODO: Verify recipient address if provided - - return { - "verified": True, - "amount_btc": output_total_btc, - "amount_usd": output_total_btc * btc_price, - "confirmations": confirmations, - "tx_hash": tx_hash - } - - except Exception as e: - logger.error(f"Error verifying Bitcoin transaction: {e}") - return { - "verified": False, - "error": str(e) - } - - async def _verify_ethereum( - self, - tx_hash: str, - expected_amount_usd: float, - recipient_address: Optional[str] = None - ) -> Dict[str, Any]: - """Verify Ethereum transaction - - Args: - tx_hash: Ethereum transaction hash - expected_amount_usd: Expected amount in USD - recipient_address: Optional recipient address - - Returns: - Verification result dict - """ - # TODO: Implement Ethereum verification using Etherscan API - return { - "verified": False, - "error": "Ethereum verification not yet implemented" - } - - async def _verify_stablecoin( - self, - crypto: str, - tx_hash: str, - expected_amount_usd: float, - recipient_address: Optional[str] = None - ) -> Dict[str, Any]: - """Verify stablecoin (USDT, USDC) transaction - - Args: - crypto: Stablecoin type (USDT or USDC) - tx_hash: Transaction hash - expected_amount_usd: Expected amount in USD - recipient_address: Optional recipient address - - Returns: - Verification result dict - """ - # TODO: Implement stablecoin verification - # USDT: Use Tron API or Etherscan API depending on network - # USDC: Use Etherscan API - return { - "verified": False, - "error": f"{crypto} verification not yet implemented" - } - - async def _get_crypto_price(self, crypto: str) -> float: - """Get current crypto price in USD - - Args: - crypto: Cryptocurrency code - - Returns: - Price in USD - """ - try: - # Use CoinGecko API (free, no API key required) - crypto_ids = { - "BTC": "bitcoin", - "ETH": "ethereum", - "USDT": "tether", - "USDC": "usd-coin" - } - - if crypto not in crypto_ids: - raise ValueError(f"Unknown crypto: {crypto}") - - url = f"https://api.coingecko.com/api/v3/simple/price?ids={crypto_ids[crypto]}&vs_currencies=usd" - - async with aiohttp.ClientSession() as session: - async with session.get(url) as response: - if response.status != 200: - raise Exception(f"CoinGecko API error: {response.status}") - - data = await response.json() - return data[crypto_ids[crypto]]["usd"] - - except Exception as e: - logger.error(f"Error getting crypto price: {e}") - # Fallback prices - fallback_prices = { - "BTC": 45000.0, - "ETH": 2500.0, - "USDT": 1.0, - "USDC": 1.0 - } - return fallback_prices.get(crypto, 0.0) - - def get_payment_address(self, crypto: str) -> Optional[str]: - """Get payment address for a cryptocurrency - - Args: - crypto: Cryptocurrency code - - Returns: - Payment address or None - """ - return self.PAYMENT_ADDRESSES.get(crypto) - - def get_payment_info(self, crypto: str, amount_usd: float) -> Dict[str, Any]: - """Get payment information for a cryptocurrency - - Args: - crypto: Cryptocurrency code - amount_usd: Amount in USD - - Returns: - Payment info dict - """ - address = self.get_payment_address(crypto) - if not address: - return { - "error": f"No payment address configured for {crypto}" - } - - return { - "crypto": crypto, - "address": address, - "amount_usd": amount_usd, - "min_confirmations": self.MIN_CONFIRMATIONS.get(crypto, 10), - "network": self._get_network_name(crypto) - } - - def _get_network_name(self, crypto: str) -> str: - """Get network name for a cryptocurrency - - Args: - crypto: Cryptocurrency code - - Returns: - Network name - """ - networks = { - "BTC": "Bitcoin Mainnet", - "ETH": "Ethereum Mainnet", - "USDT": "Tron (TRC20)", - "USDC": "Ethereum (ERC20)" - } - return networks.get(crypto, "Unknown") +__all__ = ["CryptoVerificationService", "CryptoNetwork"] diff --git a/src/services/guild.py b/src/services/guild.py deleted file mode 100644 index f789f61..0000000 --- a/src/services/guild.py +++ /dev/null @@ -1,146 +0,0 @@ -"""Guild service for business logic.""" -from typing import Optional, List, Dict, Any -from .base import BaseService -from ..database.collections import GuildCollection -from ..database.models import Guild -import discord - - -class GuildService(BaseService): - """Service for guild-related operations.""" - - def __init__(self, db): - """Initialize guild service.""" - super().__init__(db) - self.guild_collection = GuildCollection(db) - - async def get_guild(self, guild_id: int) -> Optional[Guild]: - """Get guild by ID.""" - try: - return await self.guild_collection.find_by_guild_id(guild_id) - except Exception as e: - self.log_error(f"Error getting guild {guild_id}", exc_info=e) - return None - - async def ensure_guild_exists(self, discord_guild: discord.Guild) -> Guild: - """Ensure guild exists in database.""" - try: - return await self.guild_collection.get_or_create( - guild_id=discord_guild.id, - name=discord_guild.name - ) - except Exception as e: - self.log_error(f"Error ensuring guild exists: {discord_guild.id}", exc_info=e) - raise - - async def update_guild_info(self, discord_guild: discord.Guild) -> bool: - """Update guild information from Discord.""" - try: - return await self.guild_collection.update_one( - filter={'guild_id': discord_guild.id}, - update={'$set': { - 'name': discord_guild.name, - 'member_count': discord_guild.member_count, - 'icon_url': str(discord_guild.icon.url) if discord_guild.icon else None, - 'owner_id': discord_guild.owner_id - }} - ) - except Exception as e: - self.log_error(f"Error updating guild info: {discord_guild.id}", exc_info=e) - return False - - async def get_prefix(self, guild_id: int) -> str: - """Get guild prefix.""" - guild = await self.get_guild(guild_id) - return guild.prefix if guild else '!' - - async def set_prefix(self, guild_id: int, prefix: str) -> bool: - """Set guild prefix.""" - try: - if len(prefix) > 5: - raise ValueError("Prefix too long (max 5 characters)") - - success = await self.guild_collection.update_prefix(guild_id, prefix) - if success: - self.log_info(f"Updated prefix for guild {guild_id} to '{prefix}'") - return success - except Exception as e: - self.log_error(f"Error setting prefix for guild {guild_id}", exc_info=e) - return False - - async def get_language(self, guild_id: int) -> str: - """Get guild language.""" - guild = await self.get_guild(guild_id) - return guild.language if guild else 'en' - - async def set_language(self, guild_id: int, language: str) -> bool: - """Set guild language.""" - try: - if language not in ['en', 'tr']: # Add more languages as needed - raise ValueError(f"Unsupported language: {language}") - - return await self.guild_collection.update_language(guild_id, language) - except Exception as e: - self.log_error(f"Error setting language for guild {guild_id}", exc_info=e) - return False - - async def get_setting(self, guild_id: int, key: str, default: Any = None) -> Any: - """Get a guild setting.""" - guild = await self.get_guild(guild_id) - return guild.get_setting(key, default) if guild else default - - async def set_setting(self, guild_id: int, key: str, value: Any) -> bool: - """Set a guild setting.""" - try: - return await self.guild_collection.update_setting(guild_id, key, value) - except Exception as e: - self.log_error(f"Error setting {key} for guild {guild_id}", exc_info=e) - return False - - async def has_feature(self, guild_id: int, feature: str) -> bool: - """Check if guild has a feature.""" - guild = await self.get_guild(guild_id) - return guild.has_feature(feature) if guild else False - - async def add_feature(self, guild_id: int, feature: str) -> bool: - """Add a feature to guild.""" - try: - return await self.guild_collection.add_feature(guild_id, feature) - except Exception as e: - self.log_error(f"Error adding feature {feature} to guild {guild_id}", exc_info=e) - return False - - async def remove_feature(self, guild_id: int, feature: str) -> bool: - """Remove a feature from guild.""" - try: - return await self.guild_collection.remove_feature(guild_id, feature) - except Exception as e: - self.log_error(f"Error removing feature {feature} from guild {guild_id}", exc_info=e) - return False - - async def get_active_guilds(self, days: int = 30) -> List[Guild]: - """Get recently active guilds.""" - try: - return await self.guild_collection.get_active_guilds(days) - except Exception as e: - self.log_error("Error getting active guilds", exc_info=e) - return [] - - async def get_stats(self) -> Dict[str, int]: - """Get guild statistics.""" - try: - total = await self.guild_collection.count() - active = len(await self.get_active_guilds(7)) - - return { - 'total_guilds': total, - 'active_guilds_week': active, - 'inactive_guilds': total - active - } - except Exception as e: - self.log_error("Error getting guild stats", exc_info=e) - return { - 'total_guilds': 0, - 'active_guilds_week': 0, - 'inactive_guilds': 0 - } \ No newline at end of file diff --git a/src/services/payment/__init__.py b/src/services/payment/__init__.py new file mode 100644 index 0000000..b3eb40e --- /dev/null +++ b/src/services/payment/__init__.py @@ -0,0 +1,14 @@ +""" +Payment services package. + +Provides both cryptocurrency and Stripe card payment verification +and processing for the premium subscription system. +""" +from .crypto import CryptoVerificationService, CryptoNetwork +from .stripe import StripeService + +__all__ = [ + "CryptoVerificationService", + "CryptoNetwork", + "StripeService", +] diff --git a/src/services/payment/crypto.py b/src/services/payment/crypto.py new file mode 100644 index 0000000..c9bf3de --- /dev/null +++ b/src/services/payment/crypto.py @@ -0,0 +1,305 @@ +"""Crypto payment verification service for premium subscriptions.""" + +import logging +import aiohttp +from typing import Optional, Dict, Any +from enum import Enum + +logger = logging.getLogger('crypto_verification') + + +class CryptoNetwork(str, Enum): + """Supported blockchain networks""" + BITCOIN = "bitcoin" + ETHEREUM = "ethereum" + TRON = "tron" # For USDT + + +class CryptoVerificationService: + """Service to verify cryptocurrency transactions""" + + # Blockchain explorer APIs + EXPLORERS = { + CryptoNetwork.BITCOIN: "https://blockchair.com/bitcoin/transaction/{}", + CryptoNetwork.ETHEREUM: "https://api.etherscan.io/api?module=transaction&action=gettxreceiptstatus&txhash={}&apikey=YourApiKeyToken", + CryptoNetwork.TRON: "https://apilist.tronscan.org/api/transaction-info?hash={}" + } + + # Payment addresses for each cryptocurrency + PAYMENT_ADDRESSES = { + "BTC": "bc1q...", # TODO: Add actual Bitcoin address + "ETH": "0x...", # TODO: Add actual Ethereum address + "USDT": "TR...", # TODO: Add actual USDT (TRC20) address + "USDC": "0x..." # TODO: Add actual USDC (ERC20) address + } + + # Minimum confirmations required + MIN_CONFIRMATIONS = { + "BTC": 3, + "ETH": 12, + "USDT": 19, + "USDC": 12 + } + + def __init__(self, api_keys: Optional[Dict[str, str]] = None): + """Initialize crypto verification service + + Args: + api_keys: Optional dict of API keys for blockchain explorers + """ + self.api_keys = api_keys or {} + + async def verify_transaction( + self, + crypto: str, + tx_hash: str, + expected_amount: float, + recipient_address: Optional[str] = None + ) -> Dict[str, Any]: + """Verify a cryptocurrency transaction + + Args: + crypto: Cryptocurrency code (BTC, ETH, USDT, USDC) + tx_hash: Transaction hash + expected_amount: Expected payment amount in USD + recipient_address: Optional recipient address to verify + + Returns: + Dict with verification result + """ + try: + if crypto == "BTC": + return await self._verify_bitcoin(tx_hash, expected_amount, recipient_address) + elif crypto == "ETH": + return await self._verify_ethereum(tx_hash, expected_amount, recipient_address) + elif crypto in ["USDT", "USDC"]: + return await self._verify_stablecoin(crypto, tx_hash, expected_amount, recipient_address) + else: + return { + "verified": False, + "error": f"Unsupported cryptocurrency: {crypto}" + } + except Exception as e: + logger.error(f"Error verifying transaction: {e}") + return { + "verified": False, + "error": str(e) + } + + async def _verify_bitcoin( + self, + tx_hash: str, + expected_amount_usd: float, + recipient_address: Optional[str] = None + ) -> Dict[str, Any]: + """Verify Bitcoin transaction + + Args: + tx_hash: Bitcoin transaction hash + expected_amount_usd: Expected amount in USD + recipient_address: Optional recipient address + + Returns: + Verification result dict + """ + try: + # Use Blockchair API + url = f"https://api.blockchair.com/bitcoin/dashboards/transaction/{tx_hash}" + + async with aiohttp.ClientSession() as session: + async with session.get(url) as response: + if response.status != 200: + return { + "verified": False, + "error": f"API error: {response.status}" + } + + data = await response.json() + + if "data" not in data or tx_hash not in data["data"]: + return { + "verified": False, + "error": "Transaction not found" + } + + tx_data = data["data"][tx_hash]["transaction"] + + # Get BTC price to convert amount + btc_price = await self._get_crypto_price("BTC") + expected_amount_btc = expected_amount_usd / btc_price + + # Verify confirmations + confirmations = tx_data.get("confirmations", 0) + if confirmations < self.MIN_CONFIRMATIONS["BTC"]: + return { + "verified": False, + "error": f"Insufficient confirmations: {confirmations}/{self.MIN_CONFIRMATIONS['BTC']}" + } + + # Verify amount (with 5% tolerance for fees and price fluctuation) + output_total_btc = tx_data.get("output_total", 0) / 100000000 # Satoshi to BTC + + if output_total_btc < expected_amount_btc * 0.95: + return { + "verified": False, + "error": f"Amount mismatch: {output_total_btc} BTC (expected ~{expected_amount_btc} BTC)" + } + + # TODO: Verify recipient address if provided + + return { + "verified": True, + "amount_btc": output_total_btc, + "amount_usd": output_total_btc * btc_price, + "confirmations": confirmations, + "tx_hash": tx_hash + } + + except Exception as e: + logger.error(f"Error verifying Bitcoin transaction: {e}") + return { + "verified": False, + "error": str(e) + } + + async def _verify_ethereum( + self, + tx_hash: str, + expected_amount_usd: float, + recipient_address: Optional[str] = None + ) -> Dict[str, Any]: + """Verify Ethereum transaction + + Args: + tx_hash: Ethereum transaction hash + expected_amount_usd: Expected amount in USD + recipient_address: Optional recipient address + + Returns: + Verification result dict + """ + # TODO: Implement Ethereum verification using Etherscan API + return { + "verified": False, + "error": "Ethereum verification not yet implemented" + } + + async def _verify_stablecoin( + self, + crypto: str, + tx_hash: str, + expected_amount_usd: float, + recipient_address: Optional[str] = None + ) -> Dict[str, Any]: + """Verify stablecoin (USDT, USDC) transaction + + Args: + crypto: Stablecoin type (USDT or USDC) + tx_hash: Transaction hash + expected_amount_usd: Expected amount in USD + recipient_address: Optional recipient address + + Returns: + Verification result dict + """ + # TODO: Implement stablecoin verification + # USDT: Use Tron API or Etherscan API depending on network + # USDC: Use Etherscan API + return { + "verified": False, + "error": f"{crypto} verification not yet implemented" + } + + async def _get_crypto_price(self, crypto: str) -> float: + """Get current crypto price in USD + + Args: + crypto: Cryptocurrency code + + Returns: + Price in USD + """ + try: + # Use CoinGecko API (free, no API key required) + crypto_ids = { + "BTC": "bitcoin", + "ETH": "ethereum", + "USDT": "tether", + "USDC": "usd-coin" + } + + if crypto not in crypto_ids: + raise ValueError(f"Unknown crypto: {crypto}") + + url = f"https://api.coingecko.com/api/v3/simple/price?ids={crypto_ids[crypto]}&vs_currencies=usd" + + async with aiohttp.ClientSession() as session: + async with session.get(url) as response: + if response.status != 200: + raise Exception(f"CoinGecko API error: {response.status}") + + data = await response.json() + return data[crypto_ids[crypto]]["usd"] + + except Exception as e: + logger.error(f"Error getting crypto price: {e}") + # Fallback prices + fallback_prices = { + "BTC": 45000.0, + "ETH": 2500.0, + "USDT": 1.0, + "USDC": 1.0 + } + return fallback_prices.get(crypto, 0.0) + + def get_payment_address(self, crypto: str) -> Optional[str]: + """Get payment address for a cryptocurrency + + Args: + crypto: Cryptocurrency code + + Returns: + Payment address or None + """ + return self.PAYMENT_ADDRESSES.get(crypto) + + def get_payment_info(self, crypto: str, amount_usd: float) -> Dict[str, Any]: + """Get payment information for a cryptocurrency + + Args: + crypto: Cryptocurrency code + amount_usd: Amount in USD + + Returns: + Payment info dict + """ + address = self.get_payment_address(crypto) + if not address: + return { + "error": f"No payment address configured for {crypto}" + } + + return { + "crypto": crypto, + "address": address, + "amount_usd": amount_usd, + "min_confirmations": self.MIN_CONFIRMATIONS.get(crypto, 10), + "network": self._get_network_name(crypto) + } + + def _get_network_name(self, crypto: str) -> str: + """Get network name for a cryptocurrency + + Args: + crypto: Cryptocurrency code + + Returns: + Network name + """ + networks = { + "BTC": "Bitcoin Mainnet", + "ETH": "Ethereum Mainnet", + "USDT": "Tron (TRC20)", + "USDC": "Ethereum (ERC20)" + } + return networks.get(crypto, "Unknown") diff --git a/src/services/payment/stripe.py b/src/services/payment/stripe.py new file mode 100644 index 0000000..49fd18d --- /dev/null +++ b/src/services/payment/stripe.py @@ -0,0 +1,193 @@ +"""Stripe payment service for premium subscriptions.""" + +import logging +import os +from typing import Optional, Dict, Any + +logger = logging.getLogger('stripe_service') + + +class StripeService: + """Service to handle Stripe payment operations for premium subscriptions.""" + + TIER_PRICES_USD = { + "basic": 500, # $5.00 in cents + "pro": 1500, # $15.00 in cents + "enterprise": 5000 # $50.00 in cents + } + + def __init__(self): + """Initialize Stripe service with API key from environment.""" + import stripe as stripe_lib + self._stripe = stripe_lib + self._stripe.api_key = os.getenv("STRIPE_SECRET_KEY", "") + self.webhook_secret = os.getenv("STRIPE_WEBHOOK_SECRET", "") + + if not self._stripe.api_key: + logger.warning("STRIPE_SECRET_KEY not configured") + + def create_checkout_session( + self, + guild_id: int, + user_id: int, + tier: str, + success_url: str, + cancel_url: str, + duration_days: int = 30 + ) -> Dict[str, Any]: + """Create a Stripe Checkout Session for premium purchase. + + Args: + guild_id: Discord guild ID + user_id: Discord user ID (purchaser) + tier: Premium tier (basic, pro, enterprise) + success_url: Redirect URL on successful payment + cancel_url: Redirect URL on cancelled payment + duration_days: Subscription duration in days + + Returns: + Dict with session_id and checkout_url, or error + """ + try: + amount_cents = self.TIER_PRICES_USD.get(tier.lower()) + if amount_cents is None: + return {"error": f"Unknown tier: {tier}"} + + session = self._stripe.checkout.Session.create( + payment_method_types=["card"], + line_items=[{ + "price_data": { + "currency": "usd", + "unit_amount": amount_cents, + "product_data": { + "name": f"Contro Bot - {tier.title()} Premium", + "description": f"{duration_days}-day premium subscription for your Discord server" + } + }, + "quantity": 1 + }], + mode="payment", + success_url=success_url, + cancel_url=cancel_url, + metadata={ + "guild_id": str(guild_id), + "user_id": str(user_id), + "tier": tier.lower(), + "duration_days": str(duration_days) + } + ) + + return { + "session_id": session.id, + "checkout_url": session.url + } + + except Exception as e: + logger.error(f"Error creating Stripe checkout session: {e}") + return {"error": str(e)} + + def create_payment_intent( + self, + guild_id: int, + user_id: int, + tier: str, + duration_days: int = 30 + ) -> Dict[str, Any]: + """Create a Stripe PaymentIntent for premium purchase. + + Args: + guild_id: Discord guild ID + user_id: Discord user ID (purchaser) + tier: Premium tier (basic, pro, enterprise) + duration_days: Subscription duration in days + + Returns: + Dict with client_secret and payment_intent_id, or error + """ + try: + amount_cents = self.TIER_PRICES_USD.get(tier.lower()) + if amount_cents is None: + return {"error": f"Unknown tier: {tier}"} + + intent = self._stripe.PaymentIntent.create( + amount=amount_cents, + currency="usd", + metadata={ + "guild_id": str(guild_id), + "user_id": str(user_id), + "tier": tier.lower(), + "duration_days": str(duration_days) + }, + description=f"Contro Bot {tier.title()} Premium - Guild {guild_id}" + ) + + return { + "payment_intent_id": intent.id, + "client_secret": intent.client_secret, + "amount_usd": amount_cents / 100 + } + + except Exception as e: + logger.error(f"Error creating Stripe payment intent: {e}") + return {"error": str(e)} + + def verify_webhook(self, payload: bytes, sig_header: str) -> Optional[Dict[str, Any]]: + """Verify and parse a Stripe webhook event. + + Args: + payload: Raw request body bytes + sig_header: Stripe-Signature header value + + Returns: + Parsed Stripe event dict, or None if verification fails + """ + try: + event = self._stripe.Webhook.construct_event( + payload, sig_header, self.webhook_secret + ) + return event + except self._stripe.error.SignatureVerificationError as e: + logger.error(f"Stripe webhook signature verification failed: {e}") + return None + except Exception as e: + logger.error(f"Error verifying Stripe webhook: {e}") + return None + + def get_payment_intent(self, payment_intent_id: str) -> Optional[Dict[str, Any]]: + """Retrieve a Stripe PaymentIntent by ID. + + Args: + payment_intent_id: Stripe PaymentIntent ID + + Returns: + PaymentIntent data dict, or None on error + """ + try: + intent = self._stripe.PaymentIntent.retrieve(payment_intent_id) + return { + "id": intent.id, + "status": intent.status, + "amount": intent.amount, + "currency": intent.currency, + "metadata": dict(intent.metadata), + "succeeded": intent.status == "succeeded" + } + except Exception as e: + logger.error(f"Error retrieving Stripe payment intent: {e}") + return None + + def cancel_payment_intent(self, payment_intent_id: str) -> bool: + """Cancel a Stripe PaymentIntent. + + Args: + payment_intent_id: Stripe PaymentIntent ID + + Returns: + True if cancelled, False on error + """ + try: + self._stripe.PaymentIntent.cancel(payment_intent_id) + return True + except Exception as e: + logger.error(f"Error cancelling Stripe payment intent: {e}") + return False diff --git a/src/services/stripe_service.py b/src/services/stripe_service.py new file mode 100644 index 0000000..c0ca926 --- /dev/null +++ b/src/services/stripe_service.py @@ -0,0 +1,7 @@ +"""Stripe payment service — backward-compatibility re-export. + +The canonical implementation lives in src/services/payment/stripe.py. +""" +from .payment.stripe import StripeService + +__all__ = ["StripeService"] diff --git a/src/services/user.py b/src/services/user.py deleted file mode 100644 index e7cc59c..0000000 --- a/src/services/user.py +++ /dev/null @@ -1,112 +0,0 @@ -"""User service for business logic.""" -from typing import Optional, List, Dict, Any -from .base import BaseService -from ..database.collections import UserCollection -from ..database.models import User -import discord - - -class UserService(BaseService): - """Service for user-related operations.""" - - def __init__(self, db): - """Initialize user service.""" - super().__init__(db) - self.user_collection = UserCollection(db) - - async def get_user(self, user_id: int) -> Optional[User]: - """Get user by ID.""" - try: - return await self.user_collection.find_by_user_id(user_id) - except Exception as e: - self.log_error(f"Error getting user {user_id}", exc_info=e) - return None - - async def ensure_user_exists(self, discord_user: discord.User) -> User: - """Ensure user exists in database.""" - try: - return await self.user_collection.get_or_create( - user_id=discord_user.id, - username=discord_user.name, - discriminator=discord_user.discriminator or "0" - ) - except Exception as e: - self.log_error(f"Error ensuring user exists: {discord_user.id}", exc_info=e) - raise - - async def add_xp(self, user_id: int, amount: int) -> Optional[tuple[int, int]]: - """Add XP to user and return new XP and level.""" - try: - result = await self.user_collection.add_global_xp(user_id, amount) - if result: - new_xp, new_level = result - self.log_info(f"Added {amount} XP to user {user_id}. New: {new_xp} XP, Level {new_level}") - return result - except Exception as e: - self.log_error(f"Error adding XP to user {user_id}", exc_info=e) - return None - - async def get_leaderboard(self, limit: int = 10) -> List[User]: - """Get XP leaderboard.""" - try: - return await self.user_collection.get_top_users_by_xp(limit) - except Exception as e: - self.log_error("Error getting leaderboard", exc_info=e) - return [] - - async def has_badge(self, user_id: int, badge: str) -> bool: - """Check if user has a specific badge.""" - user = await self.get_user(user_id) - return user.has_badge(badge) if user else False - - async def add_badge(self, user_id: int, badge: str) -> bool: - """Add a badge to user.""" - try: - success = await self.user_collection.add_badge(user_id, badge) - if success: - self.log_info(f"Added badge '{badge}' to user {user_id}") - return success - except Exception as e: - self.log_error(f"Error adding badge to user {user_id}", exc_info=e) - return False - - async def remove_badge(self, user_id: int, badge: str) -> bool: - """Remove a badge from user.""" - try: - success = await self.user_collection.remove_badge(user_id, badge) - if success: - self.log_info(f"Removed badge '{badge}' from user {user_id}") - return success - except Exception as e: - self.log_error(f"Error removing badge from user {user_id}", exc_info=e) - return False - - async def get_users_by_badge(self, badge: str) -> List[User]: - """Get all users with a specific badge.""" - try: - return await self.user_collection.get_users_by_badge(badge) - except Exception as e: - self.log_error(f"Error getting users with badge '{badge}'", exc_info=e) - return [] - - async def get_user_stats(self, user_id: int) -> Dict[str, Any]: - """Get user statistics.""" - user = await self.get_user(user_id) - if not user: - return { - 'exists': False, - 'xp': 0, - 'level': 0, - 'badges': [], - 'badge_count': 0 - } - - return { - 'exists': True, - 'xp': user.global_xp, - 'level': user.global_level, - 'badges': user.badges, - 'badge_count': len(user.badges), - 'next_level_xp': User.xp_for_next_level(user.global_level), - 'created_at': user.created_at - } \ No newline at end of file diff --git a/src/utils/class_utils/__init__.py b/src/utils/class_utils/__init__.py index fd56c91..3776115 100644 --- a/src/utils/class_utils/__init__.py +++ b/src/utils/class_utils/__init__.py @@ -1,194 +1,6 @@ """ Utility classes for the Contro bot. """ -import discord -from discord.ext import commands -import asyncio -from typing import List, Optional, Any, Union, Callable, Dict +from .paginator import Paginator -class Paginator(discord.ui.View): - """ - A paginator view for navigating through multiple pages of content. - """ - def __init__( - self, - pages: List[Union[discord.Embed, str]], - timeout: int = 60, - author_id: Optional[int] = None, - start_page: int = 0 - ): - """ - Initialize the paginator. - - Args: - pages: List of embeds or strings to paginate - timeout: Seconds before the paginator times out - author_id: User ID who can interact with the paginator - start_page: The initial page to display - """ - super().__init__(timeout=timeout) - self.pages = pages - self.author_id = author_id - self.current_page = start_page - self.max_pages = len(pages) - - # Disable buttons if needed - self.update_buttons() - - def update_buttons(self): - """Update the state of navigation buttons based on current page.""" - # First page button - self.first_page_button.disabled = (self.current_page == 0) - - # Previous page button - self.prev_button.disabled = (self.current_page == 0) - - # Page indicator (not a button, just for display) - self.page_indicator.label = f"{self.current_page + 1}/{self.max_pages}" - - # Next page button - self.next_button.disabled = (self.current_page == self.max_pages - 1) - - # Last page button - self.last_page_button.disabled = (self.current_page == self.max_pages - 1) - - async def interaction_check(self, interaction: discord.Interaction) -> bool: - """ - Check if the user is allowed to interact with the paginator. - - Args: - interaction: The interaction to check - - Returns: - bool: Whether the interaction is allowed - """ - if self.author_id is None: - return True - - if interaction.user.id == self.author_id: - return True - - await interaction.response.send_message( - "You cannot control this pagination because it was not started by you.", - ephemeral=True - ) - return False - - async def on_timeout(self) -> None: - """Disable all buttons when the view times out.""" - for item in self.children: - item.disabled = True - - @discord.ui.button(emoji="⏮️", style=discord.ButtonStyle.gray) - async def first_page_button(self, interaction: discord.Interaction, button: discord.ui.Button): - """Go to the first page.""" - self.current_page = 0 - self.update_buttons() - await self.show_current_page(interaction) - - @discord.ui.button(emoji="◀️", style=discord.ButtonStyle.blurple) - async def prev_button(self, interaction: discord.Interaction, button: discord.ui.Button): - """Go to the previous page.""" - self.current_page = max(0, self.current_page - 1) - self.update_buttons() - await self.show_current_page(interaction) - - @discord.ui.button(label="1/1", style=discord.ButtonStyle.gray, disabled=True) - async def page_indicator(self, interaction: discord.Interaction, button: discord.ui.Button): - """Page indicator, not functional as a button.""" - pass - - @discord.ui.button(emoji="▶️", style=discord.ButtonStyle.blurple) - async def next_button(self, interaction: discord.Interaction, button: discord.ui.Button): - """Go to the next page.""" - self.current_page = min(self.max_pages - 1, self.current_page + 1) - self.update_buttons() - await self.show_current_page(interaction) - - @discord.ui.button(emoji="⏭️", style=discord.ButtonStyle.gray) - async def last_page_button(self, interaction: discord.Interaction, button: discord.ui.Button): - """Go to the last page.""" - self.current_page = self.max_pages - 1 - self.update_buttons() - await self.show_current_page(interaction) - - async def show_current_page(self, interaction: discord.Interaction): - """ - Show the current page to the user. - - Args: - interaction: The interaction to respond to - """ - current_content = self.pages[self.current_page] - - kwargs = {} - if isinstance(current_content, discord.Embed): - kwargs["embed"] = current_content - kwargs["content"] = None - else: - kwargs["content"] = str(current_content) - kwargs["embed"] = None - - kwargs["view"] = self - - await interaction.response.edit_message(**kwargs) - - @classmethod - async def create_paginator( - cls, - ctx_or_interaction: Union[commands.Context, discord.Interaction], - pages: List[Union[discord.Embed, str]], - timeout: int = 60, - author_id: Optional[int] = None, - start_page: int = 0, - ephemeral: bool = False - ): - """ - Create and start a paginator. - - Args: - ctx_or_interaction: Context or Interaction to respond to - pages: List of embeds or strings to paginate - timeout: Seconds before the paginator times out - author_id: User ID who can interact with the paginator - start_page: The initial page to display - ephemeral: Whether the response should be ephemeral (only works with interactions) - - Returns: - Paginator: The created paginator instance - """ - if not pages: - raise ValueError("Cannot create a paginator with no pages") - - # Set author_id from context if not specified - if author_id is None: - if isinstance(ctx_or_interaction, commands.Context): - author_id = ctx_or_interaction.author.id - else: - author_id = ctx_or_interaction.user.id - - # Create the paginator - paginator = cls(pages, timeout=timeout, author_id=author_id, start_page=start_page) - - # Get current page content - current_content = pages[start_page] - - # Prepare kwargs based on content type - kwargs = {"view": paginator} - if isinstance(current_content, discord.Embed): - kwargs["embed"] = current_content - else: - kwargs["content"] = str(current_content) - - # Send/respond based on the context type - if isinstance(ctx_or_interaction, commands.Context): - await ctx_or_interaction.send(**kwargs) - else: - # It's an interaction - if ctx_or_interaction.response.is_done(): - await ctx_or_interaction.followup.send(**kwargs, ephemeral=ephemeral) - else: - kwargs["ephemeral"] = ephemeral - await ctx_or_interaction.response.send_message(**kwargs) - - return paginator +__all__ = ["Paginator"] diff --git a/src/utils/class_utils/paginator.py b/src/utils/class_utils/paginator.py new file mode 100644 index 0000000..7d85e96 --- /dev/null +++ b/src/utils/class_utils/paginator.py @@ -0,0 +1,127 @@ +""" +Paginator UI view for navigating through multiple pages of content. +""" +import discord +from discord.ext import commands +from typing import List, Optional, Union + + +class Paginator(discord.ui.View): + """ + A paginator view for navigating through multiple pages of content. + """ + def __init__( + self, + pages: List[Union[discord.Embed, str]], + timeout: int = 60, + author_id: Optional[int] = None, + start_page: int = 0 + ): + super().__init__(timeout=timeout) + self.pages = pages + self.author_id = author_id + self.current_page = start_page + self.max_pages = len(pages) + self.update_buttons() + + def update_buttons(self): + """Update the state of navigation buttons based on current page.""" + self.first_page_button.disabled = (self.current_page == 0) + self.prev_button.disabled = (self.current_page == 0) + self.page_indicator.label = f"{self.current_page + 1}/{self.max_pages}" + self.next_button.disabled = (self.current_page == self.max_pages - 1) + self.last_page_button.disabled = (self.current_page == self.max_pages - 1) + + async def interaction_check(self, interaction: discord.Interaction) -> bool: + if self.author_id is None: + return True + if interaction.user.id == self.author_id: + return True + await interaction.response.send_message( + "You cannot control this pagination because it was not started by you.", + ephemeral=True + ) + return False + + async def on_timeout(self) -> None: + for item in self.children: + item.disabled = True + + @discord.ui.button(emoji="⏮️", style=discord.ButtonStyle.gray) + async def first_page_button(self, interaction: discord.Interaction, button: discord.ui.Button): + self.current_page = 0 + self.update_buttons() + await self.show_current_page(interaction) + + @discord.ui.button(emoji="◀️", style=discord.ButtonStyle.blurple) + async def prev_button(self, interaction: discord.Interaction, button: discord.ui.Button): + self.current_page = max(0, self.current_page - 1) + self.update_buttons() + await self.show_current_page(interaction) + + @discord.ui.button(label="1/1", style=discord.ButtonStyle.gray, disabled=True) + async def page_indicator(self, interaction: discord.Interaction, button: discord.ui.Button): + pass + + @discord.ui.button(emoji="▶️", style=discord.ButtonStyle.blurple) + async def next_button(self, interaction: discord.Interaction, button: discord.ui.Button): + self.current_page = min(self.max_pages - 1, self.current_page + 1) + self.update_buttons() + await self.show_current_page(interaction) + + @discord.ui.button(emoji="⏭️", style=discord.ButtonStyle.gray) + async def last_page_button(self, interaction: discord.Interaction, button: discord.ui.Button): + self.current_page = self.max_pages - 1 + self.update_buttons() + await self.show_current_page(interaction) + + async def show_current_page(self, interaction: discord.Interaction): + current_content = self.pages[self.current_page] + kwargs = {} + if isinstance(current_content, discord.Embed): + kwargs["embed"] = current_content + kwargs["content"] = None + else: + kwargs["content"] = str(current_content) + kwargs["embed"] = None + kwargs["view"] = self + await interaction.response.edit_message(**kwargs) + + @classmethod + async def create_paginator( + cls, + ctx_or_interaction: Union[commands.Context, discord.Interaction], + pages: List[Union[discord.Embed, str]], + timeout: int = 60, + author_id: Optional[int] = None, + start_page: int = 0, + ephemeral: bool = False + ): + """Create and start a paginator.""" + if not pages: + raise ValueError("Cannot create a paginator with no pages") + + if author_id is None: + if isinstance(ctx_or_interaction, commands.Context): + author_id = ctx_or_interaction.author.id + else: + author_id = ctx_or_interaction.user.id + + paginator = cls(pages, timeout=timeout, author_id=author_id, start_page=start_page) + current_content = pages[start_page] + kwargs = {"view": paginator} + if isinstance(current_content, discord.Embed): + kwargs["embed"] = current_content + else: + kwargs["content"] = str(current_content) + + if isinstance(ctx_or_interaction, commands.Context): + await ctx_or_interaction.send(**kwargs) + else: + if ctx_or_interaction.response.is_done(): + await ctx_or_interaction.followup.send(**kwargs, ephemeral=ephemeral) + else: + kwargs["ephemeral"] = ephemeral + await ctx_or_interaction.response.send_message(**kwargs) + + return paginator diff --git a/src/utils/decorators/premium_required.py b/src/utils/decorators/premium_required.py index 42c76b9..8f8d297 100644 --- a/src/utils/decorators/premium_required.py +++ b/src/utils/decorators/premium_required.py @@ -3,120 +3,204 @@ import logging import discord from functools import wraps -from typing import Optional +from typing import Optional, Union from discord.ext import commands +from discord import app_commands from src.database.models.premium import PremiumTier logger = logging.getLogger('premium_decorator') +TIER_HIERARCHY = { + PremiumTier.BASIC: 1, + PremiumTier.PRO: 2, + PremiumTier.ENTERPRISE: 3 +} + + +def _get_guild_and_user(ctx_or_interaction): + """Extract guild_id and user from either a Context or Interaction.""" + if isinstance(ctx_or_interaction, discord.Interaction): + interaction = ctx_or_interaction + user = interaction.user + guild_id = interaction.guild.id if interaction.guild else 0 + return user, guild_id, True + else: + ctx = ctx_or_interaction + user = ctx.author + guild_id = ctx.guild.id if ctx.guild else 0 + return user, guild_id, False + + +async def _send_premium_required(ctx_or_interaction, tier: PremiumTier, is_interaction: bool): + """Send a premium-required embed via ctx or interaction.""" + embed = discord.Embed( + title="🔒 Premium Required", + description=( + f"This command requires **{tier.value.title()} Premium** or higher.\n\n" + f"Use `/premium info` to learn more about premium benefits." + ), + color=discord.Color.gold() + ) + if is_interaction: + if ctx_or_interaction.response.is_done(): + await ctx_or_interaction.followup.send(embed=embed, ephemeral=True) + else: + await ctx_or_interaction.response.send_message(embed=embed, ephemeral=True) + else: + await ctx_or_interaction.send(embed=embed) + + +async def _send_tier_required(ctx_or_interaction, tier: PremiumTier, user_tier_str: str, is_interaction: bool): + """Send a higher-tier-required embed.""" + embed = discord.Embed( + title="🔒 Higher Premium Tier Required", + description=( + f"This command requires **{tier.value.title()} Premium** or higher.\n" + f"Your current tier: **{user_tier_str.title()}**\n\n" + f"Use `/premium upgrade` to upgrade your subscription." + ), + color=discord.Color.gold() + ) + if is_interaction: + if ctx_or_interaction.response.is_done(): + await ctx_or_interaction.followup.send(embed=embed, ephemeral=True) + else: + await ctx_or_interaction.response.send_message(embed=embed, ephemeral=True) + else: + await ctx_or_interaction.send(embed=embed) + def premium_required(tier: Optional[PremiumTier] = PremiumTier.BASIC): - """Decorator to restrict commands to premium users - + """Decorator to restrict commands to premium guilds. + + Works with both prefix commands (ctx) and slash/hybrid commands (interaction). + Args: tier: Minimum required premium tier (default: BASIC) - + Usage: @premium_required() async def my_command(self, ctx): ... - + @premium_required(tier=PremiumTier.PRO) async def pro_command(self, ctx): ... """ def decorator(func): @wraps(func) - async def wrapper(self, ctx, *args, **kwargs): - # Get the bot instance - bot = self.bot if hasattr(self, 'bot') else ctx.bot - - # Get premium model from bot - if not hasattr(bot, 'premium_model'): + async def wrapper(self, ctx_or_interaction, *args, **kwargs): + bot = self.bot if hasattr(self, 'bot') else None + + if bot is None or not hasattr(bot, 'premium_model'): logger.error("Premium model not initialized in bot") - await ctx.send("❌ Premium system is not available.") + is_interaction = isinstance(ctx_or_interaction, discord.Interaction) + if is_interaction: + await ctx_or_interaction.response.send_message( + "❌ Premium system is not available.", ephemeral=True + ) + else: + await ctx_or_interaction.send("❌ Premium system is not available.") return - - # Check premium status - premium_data = await bot.premium_model.get_user_premium( - ctx.author.id, - ctx.guild.id if ctx.guild else 0 - ) - + + user, guild_id, is_interaction = _get_guild_and_user(ctx_or_interaction) + + premium_data = await bot.premium_model.get_guild_premium(guild_id) + if not premium_data: - embed = discord.Embed( - title="🔒 Premium Required", - description=( - f"This command requires **{tier.value.title()} Premium** or higher.\n\n" - f"Use `/premium info` to learn more about premium benefits." - ), - color=discord.Color.gold() - ) - await ctx.send(embed=embed) + await _send_premium_required(ctx_or_interaction, tier, is_interaction) return - - # Check tier level - user_tier = premium_data.get("premium_tier", "basic") - tier_hierarchy = { - PremiumTier.BASIC: 1, - PremiumTier.PRO: 2, - PremiumTier.ENTERPRISE: 3 - } - - required_level = tier_hierarchy.get(tier, 1) - user_level = tier_hierarchy.get(user_tier, 0) - + + user_tier_str = premium_data.get("premium_tier", "basic") + # Safely resolve to PremiumTier enum + try: + user_tier_enum = PremiumTier(user_tier_str) + except ValueError: + user_tier_enum = PremiumTier.BASIC + + required_level = TIER_HIERARCHY.get(tier, 1) + user_level = TIER_HIERARCHY.get(user_tier_enum, 0) + if user_level < required_level: - embed = discord.Embed( - title="🔒 Higher Premium Tier Required", - description=( - f"This command requires **{tier.value.title()} Premium** or higher.\n" - f"Your current tier: **{user_tier.title()}**\n\n" - f"Use `/premium upgrade` to upgrade your subscription." - ), - color=discord.Color.gold() - ) - await ctx.send(embed=embed) + await _send_tier_required(ctx_or_interaction, tier, user_tier_str, is_interaction) return - - # User has required premium, execute command - return await func(self, ctx, *args, **kwargs) - + + return await func(self, ctx_or_interaction, *args, **kwargs) + return wrapper return decorator +def _premium_check(tier: PremiumTier = PremiumTier.BASIC): + """app_commands.check-compatible premium check predicate factory.""" + async def predicate(interaction: discord.Interaction) -> bool: + bot = interaction.client + if not hasattr(bot, 'premium_model'): + raise app_commands.CheckFailure("Premium system is not available.") + + guild_id = interaction.guild.id if interaction.guild else 0 + premium_data = await bot.premium_model.get_guild_premium(guild_id) + + if not premium_data: + raise app_commands.CheckFailure( + f"This command requires **{tier.value.title()} Premium** or higher. " + f"Use `/premium info` to learn more." + ) + + user_tier_str = premium_data.get("premium_tier", "basic") + try: + user_tier_enum = PremiumTier(user_tier_str) + except ValueError: + user_tier_enum = PremiumTier.BASIC + + required_level = TIER_HIERARCHY.get(tier, 1) + user_level = TIER_HIERARCHY.get(user_tier_enum, 0) + + if user_level < required_level: + raise app_commands.CheckFailure( + f"This command requires **{tier.value.title()} Premium** or higher. " + f"Your current tier: **{user_tier_str.title()}**." + ) + + return True + + return predicate + + +def app_premium_required(tier: PremiumTier = PremiumTier.BASIC): + """app_commands.check decorator for slash-only commands.""" + return app_commands.check(_premium_check(tier)) + + def feature_required(feature: str): - """Decorator to restrict commands based on specific premium features - + """Decorator to restrict commands based on specific premium features. + + Works with both prefix commands (ctx) and slash/hybrid commands (interaction). + Args: feature: Feature name to check - - Usage: - @feature_required("api_access") - async def api_command(self, ctx): - ... """ def decorator(func): @wraps(func) - async def wrapper(self, ctx, *args, **kwargs): - # Get the bot instance - bot = self.bot if hasattr(self, 'bot') else ctx.bot - - # Get premium model from bot - if not hasattr(bot, 'premium_model'): + async def wrapper(self, ctx_or_interaction, *args, **kwargs): + bot = self.bot if hasattr(self, 'bot') else None + + if bot is None or not hasattr(bot, 'premium_model'): logger.error("Premium model not initialized in bot") - await ctx.send("❌ Premium system is not available.") + is_interaction = isinstance(ctx_or_interaction, discord.Interaction) + if is_interaction: + await ctx_or_interaction.response.send_message( + "❌ Premium system is not available.", ephemeral=True + ) + else: + await ctx_or_interaction.send("❌ Premium system is not available.") return - - # Check feature access - has_access = await bot.premium_model.has_feature( - ctx.author.id, - ctx.guild.id if ctx.guild else 0, - feature - ) - + + _, guild_id, is_interaction = _get_guild_and_user(ctx_or_interaction) + + has_access = await bot.premium_model.has_feature(0, guild_id, feature) + if not has_access: embed = discord.Embed( title="🔒 Premium Feature", @@ -126,11 +210,16 @@ async def wrapper(self, ctx, *args, **kwargs): ), color=discord.Color.gold() ) - await ctx.send(embed=embed) + if is_interaction: + if ctx_or_interaction.response.is_done(): + await ctx_or_interaction.followup.send(embed=embed, ephemeral=True) + else: + await ctx_or_interaction.response.send_message(embed=embed, ephemeral=True) + else: + await ctx_or_interaction.send(embed=embed) return - - # User has access, execute command - return await func(self, ctx, *args, **kwargs) - + + return await func(self, ctx_or_interaction, *args, **kwargs) + return wrapper return decorator diff --git a/src/utils/discord/__init__.py b/src/utils/discord/__init__.py index 056d6af..3a03b33 100644 --- a/src/utils/discord/__init__.py +++ b/src/utils/discord/__init__.py @@ -1,134 +1,16 @@ """ Discord helper utilities for common operations. -This module provides helper functions for Discord-related operations. """ -import discord -from discord.ext import commands -import logging -from typing import Optional, Union, List, Dict, Any, Tuple, TypeVar +from .helpers import ( + ContextOrInteraction, + check_if_ctx_or_interaction, + respond_to_ctx_or_interaction, + create_basic_embed, +) -logger = logging.getLogger(__name__) - -# Type aliases for clarity -ContextOrInteraction = Union[commands.Context, discord.Interaction] -T = TypeVar('T') - -def check_if_ctx_or_interaction(ctx_or_interaction: ContextOrInteraction) -> Tuple[bool, bool]: - """ - Check if the provided object is a Context or Interaction and return appropriate flags. - - Args: - ctx_or_interaction: Either a commands.Context or discord.Interaction object - - Returns: - Tuple[bool, bool]: (is_context, is_interaction) - """ - is_context = isinstance(ctx_or_interaction, commands.Context) - is_interaction = isinstance(ctx_or_interaction, discord.Interaction) - - return is_context, is_interaction - -async def respond_to_ctx_or_interaction( - ctx_or_interaction: ContextOrInteraction, - content: Optional[str] = None, - embed: Optional[discord.Embed] = None, - view: Optional[discord.ui.View] = None, - ephemeral: bool = False -) -> None: - """ - Respond to either a Context or Interaction with consistent handling. - - Args: - ctx_or_interaction: Either a commands.Context or discord.Interaction - content: Text content to send - embed: Embed to send - view: View to attach - ephemeral: Whether the response should be ephemeral (only applies to interactions) - """ - is_context, is_interaction = check_if_ctx_or_interaction(ctx_or_interaction) - - try: - if is_interaction: - # Handle interaction response - interaction = ctx_or_interaction - - # Check if we've already responded - if interaction.response.is_done(): - # If we've already responded, use followup - await interaction.followup.send( - content=content, - embed=embed, - view=view, - ephemeral=ephemeral - ) - else: - # Initial response - await interaction.response.send_message( - content=content, - embed=embed, - view=view, - ephemeral=ephemeral - ) - elif is_context: - # Handle context response - await ctx_or_interaction.send( - content=content, - embed=embed, - view=view - ) - else: - logger.error(f"Unknown type for response: {type(ctx_or_interaction)}") - - except Exception as e: - logger.error(f"Error responding to context/interaction: {e}") - - # Try fallback method if possible - try: - if hasattr(ctx_or_interaction, 'send'): - await ctx_or_interaction.send( - content=content or "An error occurred while processing your request.", - embed=embed - ) - except Exception as fallback_error: - logger.error(f"Fallback response also failed: {fallback_error}") - -def create_basic_embed( - title: Optional[str] = None, - description: Optional[str] = None, - color: Optional[discord.Color] = None, - footer: Optional[str] = None, - thumbnail: Optional[str] = None, - image: Optional[str] = None -) -> discord.Embed: - """ - Create a basic Discord embed with common parameters. - - Args: - title: Title of the embed - description: Description of the embed - color: Color of the embed - footer: Footer text - thumbnail: URL for thumbnail - image: URL for image - - Returns: - discord.Embed: The created embed - """ - embed = discord.Embed( - title=title, - description=description, - color=color or discord.Color.blurple() - ) - - if footer: - embed.set_footer(text=footer) - - if thumbnail: - embed.set_thumbnail(url=thumbnail) - - if image: - embed.set_image(url=image) - - return embed - -# Other helper functions can be added here as needed +__all__ = [ + "ContextOrInteraction", + "check_if_ctx_or_interaction", + "respond_to_ctx_or_interaction", + "create_basic_embed", +] diff --git a/src/utils/discord/helpers.py b/src/utils/discord/helpers.py new file mode 100644 index 0000000..6765661 --- /dev/null +++ b/src/utils/discord/helpers.py @@ -0,0 +1,81 @@ +""" +Discord helper utilities for common operations. +""" +import discord +from discord.ext import commands +import logging +from typing import Optional, Union, Tuple, TypeVar + +logger = logging.getLogger(__name__) + +ContextOrInteraction = Union[commands.Context, discord.Interaction] +T = TypeVar('T') + + +def check_if_ctx_or_interaction(ctx_or_interaction: ContextOrInteraction) -> Tuple[bool, bool]: + """Return (is_context, is_interaction) flags.""" + return ( + isinstance(ctx_or_interaction, commands.Context), + isinstance(ctx_or_interaction, discord.Interaction), + ) + + +async def respond_to_ctx_or_interaction( + ctx_or_interaction: ContextOrInteraction, + content: Optional[str] = None, + embed: Optional[discord.Embed] = None, + view: Optional[discord.ui.View] = None, + ephemeral: bool = False +) -> None: + """Respond to either a Context or Interaction with consistent handling.""" + is_context, is_interaction = check_if_ctx_or_interaction(ctx_or_interaction) + + try: + if is_interaction: + interaction = ctx_or_interaction + if interaction.response.is_done(): + await interaction.followup.send( + content=content, embed=embed, view=view, ephemeral=ephemeral + ) + else: + await interaction.response.send_message( + content=content, embed=embed, view=view, ephemeral=ephemeral + ) + elif is_context: + await ctx_or_interaction.send(content=content, embed=embed, view=view) + else: + logger.error(f"Unknown type for response: {type(ctx_or_interaction)}") + + except Exception as e: + logger.error(f"Error responding to context/interaction: {e}") + try: + if hasattr(ctx_or_interaction, 'send'): + await ctx_or_interaction.send( + content=content or "An error occurred while processing your request.", + embed=embed + ) + except Exception as fallback_error: + logger.error(f"Fallback response also failed: {fallback_error}") + + +def create_basic_embed( + title: Optional[str] = None, + description: Optional[str] = None, + color: Optional[discord.Color] = None, + footer: Optional[str] = None, + thumbnail: Optional[str] = None, + image: Optional[str] = None +) -> discord.Embed: + """Create a basic Discord embed with common parameters.""" + embed = discord.Embed( + title=title, + description=description, + color=color or discord.Color.blurple() + ) + if footer: + embed.set_footer(text=footer) + if thumbnail: + embed.set_thumbnail(url=thumbnail) + if image: + embed.set_image(url=image) + return embed diff --git a/src/utils/formatting/__init__.py b/src/utils/formatting/__init__.py index 3598aa3..c2d962f 100644 --- a/src/utils/formatting/__init__.py +++ b/src/utils/formatting/__init__.py @@ -1,238 +1,24 @@ """ Formatting utilities for Discord messages, embeds, and other content. """ -import discord -import re -import datetime -import logging -from typing import Optional, Union, Dict, Any, List - -def create_embed( - title: Optional[str] = None, - description: Optional[str] = None, - color: Optional[Union[discord.Color, int]] = None, - author: Optional[Dict[str, Any]] = None, - fields: Optional[List[Dict[str, Any]]] = None, - footer: Optional[Dict[str, Any]] = None, - image: Optional[str] = None, - thumbnail: Optional[str] = None, - timestamp: Optional[datetime.datetime] = None -) -> discord.Embed: - """ - Create a Discord embed with specified parameters. - - Args: - title: The title of the embed - description: The description content - color: The color of the embed sidebar - author: Dict with 'name', 'url' (optional), 'icon_url' (optional) - fields: List of dicts with 'name', 'value', 'inline' (optional) - footer: Dict with 'text', 'icon_url' (optional) - image: URL of the main image - thumbnail: URL of the thumbnail image - timestamp: Datetime for the embed timestamp - - Returns: - discord.Embed: The created embed - """ - # Default color if none provided - if color is None: - color = discord.Color.blurple() - - # Create the embed - embed = discord.Embed( - title=title, - description=description, - color=color, - timestamp=timestamp - ) - - # Add author if provided - if author: - name = author.get('name', '') - url = author.get('url') - icon_url = author.get('icon_url') - embed.set_author(name=name, url=url, icon_url=icon_url) - - # Add fields if provided - if fields: - for field in fields: - name = field.get('name', '') - value = field.get('value', '') - inline = field.get('inline', False) - embed.add_field(name=name, value=value, inline=inline) - - # Add footer if provided - if footer: - text = footer.get('text', '') - icon_url = footer.get('icon_url') - embed.set_footer(text=text, icon_url=icon_url) - - # Add images if provided - if image: - embed.set_image(url=image) - - if thumbnail: - embed.set_thumbnail(url=thumbnail) - - return embed - -def truncate_text(text: str, max_length: int = 2000, suffix: str = '...') -> str: - """ - Truncate text to the specified maximum length, adding a suffix if truncated. - - Args: - text: The text to truncate - max_length: Maximum length allowed - suffix: Suffix to add when truncated - - Returns: - str: Truncated text - """ - if len(text) <= max_length: - return text - - return text[:max_length - len(suffix)] + suffix - -def format_timestamp(timestamp: Union[int, float, datetime.datetime], format_type: str = 'f') -> str: - """ - Format a timestamp for Discord display. - - Args: - timestamp: Unix timestamp or datetime object - format_type: Discord timestamp format (t, T, d, D, f, F, R) - - Returns: - str: Formatted Discord timestamp - """ - if isinstance(timestamp, datetime.datetime): - timestamp = int(timestamp.timestamp()) - - return f"" - -def sanitize_mentions(text: str) -> str: - """ - Sanitize Discord mentions in a text string. - - Args: - text: Text to sanitize - - Returns: - str: Sanitized text with escaped mentions - """ - # Escape @everyone and @here - text = text.replace('@everyone', '@\u200beveryone') - text = text.replace('@here', '@\u200bhere') - - # Escape user/role/channel mentions - text = re.sub(r'<@(!?&?\d+)>', '<@\u200b\\1>', text) - - return text - -def create_progress_bar(progress: float, length: int = 10, - filled_char: str = '■', empty_char: str = '□') -> str: - """ - Create a text-based progress bar. - - Args: - progress: Progress value between 0 and 1 - length: Number of characters in the bar - filled_char: Character for filled portion - empty_char: Character for empty portion - - Returns: - str: Text progress bar - """ - # Ensure progress is between 0 and 1 - progress = max(0, min(1, progress)) - - # Calculate filled and empty lengths - filled_length = int(progress * length) - empty_length = length - filled_length - - # Construct the bar - bar = filled_char * filled_length + empty_char * empty_length - - return bar - -def hex_to_int(hex_string: str) -> int: - """ - Convert a hexadecimal color string to an integer. - - Args: - hex_string: Hex color string (with or without # prefix) - - Returns: - int: Integer representation of the color - """ - # Remove '#' if present - if hex_string.startswith('#'): - hex_string = hex_string[1:] - - # Validate hexadecimal format - if not all(c in '0123456789ABCDEFabcdef' for c in hex_string): - raise ValueError(f"Invalid hexadecimal color string: {hex_string}") - - # Convert to integer - return int(hex_string, 16) - -def calculate_how_long_ago_member_joined(member) -> str: - """ - Calculate and format how long ago a member joined the server. - - Args: - member: discord.Member object - - Returns: - str: Formatted time string - """ - if not member.joined_at: - return "Unknown" - - now = datetime.datetime.now(datetime.timezone.utc) - joined = member.joined_at - - # Calculate time difference - delta = now - joined - - # Format the difference in a human-readable way - days = delta.days - years = days // 365 - months = (days % 365) // 30 - remaining_days = days % 30 - - if years > 0: - return f"{years} year{'s' if years != 1 else ''}, {months} month{'s' if months != 1 else ''}, {remaining_days} day{'s' if remaining_days != 1 else ''}" - elif months > 0: - return f"{months} month{'s' if months != 1 else ''}, {remaining_days} day{'s' if remaining_days != 1 else ''}" - else: - return f"{remaining_days} day{'s' if remaining_days != 1 else ''}" - -def calculate_how_long_ago_member_created(member) -> str: - """ - Calculate and format how long ago a member created their account. - - Args: - member: discord.Member object - - Returns: - str: Formatted time string - """ - now = datetime.datetime.now(datetime.timezone.utc) - created = member.created_at - - # Calculate time difference - delta = now - created - - # Format the difference in a human-readable way - days = delta.days - years = days // 365 - months = (days % 365) // 30 - remaining_days = days % 30 - - if years > 0: - return f"{years} year{'s' if years != 1 else ''}, {months} month{'s' if months != 1 else ''}, {remaining_days} day{'s' if remaining_days != 1 else ''}" - elif months > 0: - return f"{months} month{'s' if months != 1 else ''}, {remaining_days} day{'s' if remaining_days != 1 else ''}" - else: - return f"{remaining_days} day{'s' if remaining_days != 1 else ''}" +from .embeds import create_embed +from .text import ( + truncate_text, + format_timestamp, + sanitize_mentions, + create_progress_bar, + hex_to_int, + calculate_how_long_ago_member_joined, + calculate_how_long_ago_member_created, +) + +__all__ = [ + "create_embed", + "truncate_text", + "format_timestamp", + "sanitize_mentions", + "create_progress_bar", + "hex_to_int", + "calculate_how_long_ago_member_joined", + "calculate_how_long_ago_member_created", +] diff --git a/src/utils/formatting/embeds.py b/src/utils/formatting/embeds.py new file mode 100644 index 0000000..a3ba08c --- /dev/null +++ b/src/utils/formatting/embeds.py @@ -0,0 +1,55 @@ +""" +Discord embed creation utilities. +""" +import discord +import datetime +from typing import Optional, Union, Dict, Any, List + + +def create_embed( + title: Optional[str] = None, + description: Optional[str] = None, + color: Optional[Union[discord.Color, int]] = None, + author: Optional[Dict[str, Any]] = None, + fields: Optional[List[Dict[str, Any]]] = None, + footer: Optional[Dict[str, Any]] = None, + image: Optional[str] = None, + thumbnail: Optional[str] = None, + timestamp: Optional[datetime.datetime] = None +) -> discord.Embed: + """Create a Discord embed with specified parameters.""" + if color is None: + color = discord.Color.blurple() + + embed = discord.Embed( + title=title, + description=description, + color=color, + timestamp=timestamp + ) + + if author: + embed.set_author( + name=author.get('name', ''), + url=author.get('url'), + icon_url=author.get('icon_url') + ) + + if fields: + for field in fields: + embed.add_field( + name=field.get('name', ''), + value=field.get('value', ''), + inline=field.get('inline', False) + ) + + if footer: + embed.set_footer(text=footer.get('text', ''), icon_url=footer.get('icon_url')) + + if image: + embed.set_image(url=image) + + if thumbnail: + embed.set_thumbnail(url=thumbnail) + + return embed diff --git a/src/utils/formatting/text.py b/src/utils/formatting/text.py new file mode 100644 index 0000000..1f3e7cc --- /dev/null +++ b/src/utils/formatting/text.py @@ -0,0 +1,86 @@ +""" +Text and timestamp formatting utilities. +""" +import re +import datetime +from typing import Union + + +def truncate_text(text: str, max_length: int = 2000, suffix: str = '...') -> str: + """Truncate text to the specified maximum length, adding a suffix if truncated.""" + if len(text) <= max_length: + return text + return text[:max_length - len(suffix)] + suffix + + +def format_timestamp( + timestamp: Union[int, float, datetime.datetime], + format_type: str = 'f' +) -> str: + """Format a timestamp for Discord display. + + format_type options: t, T, d, D, f, F, R + """ + if isinstance(timestamp, datetime.datetime): + timestamp = int(timestamp.timestamp()) + return f"" + + +def sanitize_mentions(text: str) -> str: + """Sanitize Discord mentions in a text string.""" + text = text.replace('@everyone', '@\u200beveryone') + text = text.replace('@here', '@\u200bhere') + text = re.sub(r'<@(!?&?\d+)>', '<@\u200b\\1>', text) + return text + + +def create_progress_bar( + progress: float, + length: int = 10, + filled_char: str = '■', + empty_char: str = '□' +) -> str: + """Create a text-based progress bar.""" + progress = max(0, min(1, progress)) + filled_length = int(progress * length) + empty_length = length - filled_length + return filled_char * filled_length + empty_char * empty_length + + +def hex_to_int(hex_string: str) -> int: + """Convert a hexadecimal color string to an integer.""" + if hex_string.startswith('#'): + hex_string = hex_string[1:] + if not all(c in '0123456789ABCDEFabcdef' for c in hex_string): + raise ValueError(f"Invalid hexadecimal color string: {hex_string}") + return int(hex_string, 16) + + +def calculate_how_long_ago_member_joined(member) -> str: + """Calculate and format how long ago a member joined the server.""" + if not member.joined_at: + return "Unknown" + now = datetime.datetime.now(datetime.timezone.utc) + return _format_timedelta(now - member.joined_at) + + +def calculate_how_long_ago_member_created(member) -> str: + """Calculate and format how long ago a member created their account.""" + now = datetime.datetime.now(datetime.timezone.utc) + return _format_timedelta(now - member.created_at) + + +def _format_timedelta(delta: datetime.timedelta) -> str: + days = delta.days + years = days // 365 + months = (days % 365) // 30 + remaining_days = days % 30 + + def plural(n, word): + return f"{n} {word}{'s' if n != 1 else ''}" + + if years > 0: + return f"{plural(years, 'year')}, {plural(months, 'month')}, {plural(remaining_days, 'day')}" + elif months > 0: + return f"{plural(months, 'month')}, {plural(remaining_days, 'day')}" + return plural(remaining_days, 'day')