From 608bfd2215ffc4b57dfdd68f5435bec10fbc7c31 Mon Sep 17 00:00:00 2001 From: Tanush1912 <2541228@dundee.ac.uk> Date: Mon, 2 Mar 2026 01:58:52 +0000 Subject: [PATCH] Add GitHub App integration, enhance live mode, and improve diff engine Integrate GitHub App OAuth flow with installation management, encrypted token storage, and webhook handling. Enhance live mode with improved handshake visualization and source-file-aware diffing. Improve the diff engine with better endpoint matching and add tests. Overhaul the settings page UI and add a Vercel deploy workflow. --- .github/workflows/deploy.yml | 55 +++ cohesion_backend/.env.example | 8 +- cohesion_backend/.gitignore | 6 +- cohesion_backend/cmd/server/main.go | 23 +- cohesion_backend/go.mod | 3 + cohesion_backend/go.sum | 6 + cohesion_backend/internal/config/config.go | 20 + .../controlplane/handlers/github_app.go | 117 ++++++ .../controlplane/handlers/handlers.go | 90 ++-- .../controlplane/handlers/live_handlers.go | 131 +++++- .../internal/controlplane/router.go | 38 +- cohesion_backend/internal/crypto/crypto.go | 101 +++++ cohesion_backend/internal/models/models.go | 10 + .../internal/repository/diff_repo.go | 14 +- .../internal/repository/endpoint_repo.go | 120 +++++- .../repository/github_installation_repo.go | 73 ++++ .../internal/repository/project_repo.go | 13 +- .../internal/repository/user_settings_repo.go | 27 +- .../internal/services/diff_service.go | 109 ++--- .../internal/services/endpoint_service.go | 10 - .../services/github_installation_service.go | 41 ++ .../internal/services/live_service.go | 134 +++--- .../internal/services/schema_service.go | 59 ++- .../005_github_app_installations.down.sql | 1 + .../005_github_app_installations.up.sql | 11 + .../pkg/analyzer/gemini/filediscovery.go | 80 +--- cohesion_backend/pkg/analyzer/interface.go | 28 -- cohesion_backend/pkg/diff/engine.go | 198 ++++++--- cohesion_backend/pkg/diff/engine_test.go | 93 +++++ cohesion_backend/pkg/diff/types.go | 5 +- cohesion_backend/pkg/github/appauth.go | 40 ++ cohesion_backend/pkg/github/fetcher.go | 54 +-- cohesion_backend/pkg/sourcefile/sourcefile.go | 79 ++++ cohesion_frontend/.env.example | 3 +- cohesion_frontend/src/app/live/page.tsx | 35 +- cohesion_frontend/src/app/settings/page.tsx | 239 +++++++++-- .../components/live/live-handshake-view.tsx | 383 ++++++++++++++++++ .../src/components/live/live-onboarding.tsx | 8 +- .../project/upload-schema-dialog.tsx | 2 +- .../visualization/contract-schema.tsx | 12 +- .../components/visualization/diff-panel.tsx | 2 +- .../visualization/handshake-view.tsx | 11 +- .../components/visualization/schema-tree.tsx | 3 +- cohesion_frontend/src/lib/api.ts | 27 +- 44 files changed, 2041 insertions(+), 481 deletions(-) create mode 100644 .github/workflows/deploy.yml create mode 100644 cohesion_backend/internal/controlplane/handlers/github_app.go create mode 100644 cohesion_backend/internal/crypto/crypto.go create mode 100644 cohesion_backend/internal/repository/github_installation_repo.go create mode 100644 cohesion_backend/internal/services/github_installation_service.go create mode 100644 cohesion_backend/migrations/005_github_app_installations.down.sql create mode 100644 cohesion_backend/migrations/005_github_app_installations.up.sql create mode 100644 cohesion_backend/pkg/github/appauth.go create mode 100644 cohesion_backend/pkg/sourcefile/sourcefile.go create mode 100644 cohesion_frontend/src/components/live/live-handshake-view.tsx diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml new file mode 100644 index 0000000..ae7d77e --- /dev/null +++ b/.github/workflows/deploy.yml @@ -0,0 +1,55 @@ +name: Deploy Frontend to Vercel + +on: + push: + branches: [main] + paths: + - "cohesion_frontend/**" + - ".github/workflows/deploy.yml" + pull_request: + branches: [main] + paths: + - "cohesion_frontend/**" + - ".github/workflows/deploy.yml" + +env: + VERCEL_ORG_ID: ${{ secrets.VERCEL_ORG_ID }} + VERCEL_PROJECT_ID: ${{ secrets.VERCEL_PROJECT_ID }} + +jobs: + deploy: + runs-on: ubuntu-latest + defaults: + run: + working-directory: cohesion_frontend + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Install Vercel CLI + run: npm install -g vercel + + - name: Pull Vercel Environment + run: vercel pull --yes --environment=${{ github.event_name == 'push' && 'production' || 'preview' }} --token=${{ secrets.VERCEL_TOKEN }} + + - name: Build + run: vercel build ${{ github.event_name == 'push' && '--prod' || '' }} --token=${{ secrets.VERCEL_TOKEN }} + + - name: Deploy + id: deploy + run: | + URL=$(vercel deploy --prebuilt ${{ github.event_name == 'push' && '--prod' || '' }} --token=${{ secrets.VERCEL_TOKEN }}) + echo "url=$URL" >> "$GITHUB_OUTPUT" + + - name: Comment Preview URL on PR + if: github.event_name == 'pull_request' + uses: actions/github-script@v7 + with: + script: | + github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.issue.number, + body: `**Vercel Preview:** ${{ steps.deploy.outputs.url }}` + }) diff --git a/cohesion_backend/.env.example b/cohesion_backend/.env.example index ec956b3..a492766 100644 --- a/cohesion_backend/.env.example +++ b/cohesion_backend/.env.example @@ -4,4 +4,10 @@ ENVIRONMENT=development GEMINI_API_KEY=your-gemini-api-key GEMINI_MODEL=gemini-2.0-flash CLERK_SECRET_KEY=sk_test_... -CLERK_PUBLISHABLE_KEY=pk_test_... \ No newline at end of file +CLERK_PUBLISHABLE_KEY=pk_test_... +GITHUB_APP_ID=123456 +GITHUB_APP_PRIVATE_KEY_PATH=./github-app-private-key.pem +GITHUB_APP_CLIENT_ID=Iv1.abc123 +GITHUB_APP_CLIENT_SECRET=secret +GITHUB_APP_SLUG=cohesion +FRONTEND_URL=http://localhost:3000 \ No newline at end of file diff --git a/cohesion_backend/.gitignore b/cohesion_backend/.gitignore index f1000c4..2ad566e 100644 --- a/cohesion_backend/.gitignore +++ b/cohesion_backend/.gitignore @@ -1,5 +1,5 @@ .env -.env.local +.env.* *.exe *.exe~ *.dll @@ -13,3 +13,7 @@ vendor/ .vscode/ *.log tmp/ +bin/ +server +cohesion-server +.next/ diff --git a/cohesion_backend/cmd/server/main.go b/cohesion_backend/cmd/server/main.go index 641a571..50f4c4b 100644 --- a/cohesion_backend/cmd/server/main.go +++ b/cohesion_backend/cmd/server/main.go @@ -16,6 +16,7 @@ import ( "github.com/cohesion-api/cohesion_backend/internal/services" "github.com/cohesion-api/cohesion_backend/pkg/analyzer" geminianalyzer "github.com/cohesion-api/cohesion_backend/pkg/analyzer/gemini" + ghpkg "github.com/cohesion-api/cohesion_backend/pkg/github" ) func main() { @@ -39,6 +40,7 @@ func main() { schemaRepo := repository.NewSchemaRepository(db) diffRepo := repository.NewDiffRepository(db) userSettingsRepo := repository.NewUserSettingsRepository(db) + ghInstallRepo := repository.NewGitHubInstallationRepository(db) projectService := services.NewProjectService(projectRepo, endpointRepo) endpointService := services.NewEndpointService(endpointRepo, schemaRepo) @@ -46,19 +48,26 @@ func main() { diffService := services.NewDiffService(diffRepo, schemaRepo, endpointRepo) liveService := services.NewLiveService() userSettingsService := services.NewUserSettingsService(userSettingsRepo) + ghInstallService := services.NewGitHubInstallationService(ghInstallRepo) var codeAnalyzer analyzer.Analyzer if cfg.GeminiAPIKey != "" { codeAnalyzer = geminianalyzer.New(cfg.GeminiAPIKey, cfg.GeminiModel) } + ghAppAuth := ghpkg.NewAppAuth(cfg.GitHubAppID, cfg.GitHubAppPrivateKey) + svc := &controlplane.Services{ - ProjectService: projectService, - EndpointService: endpointService, - SchemaService: schemaService, - DiffService: diffService, - LiveService: liveService, - UserSettingsService: userSettingsService, - Analyzer: codeAnalyzer, + ProjectService: projectService, + EndpointService: endpointService, + SchemaService: schemaService, + DiffService: diffService, + LiveService: liveService, + UserSettingsService: userSettingsService, + GitHubInstallationService: ghInstallService, + Analyzer: codeAnalyzer, + GitHubAppAuth: ghAppAuth, + GitHubAppSlug: cfg.GitHubAppSlug, + FrontendURL: cfg.FrontendURL, } router := controlplane.NewRouter(svc) diff --git a/cohesion_backend/go.mod b/cohesion_backend/go.mod index 091c782..2b6416f 100644 --- a/cohesion_backend/go.mod +++ b/cohesion_backend/go.mod @@ -3,6 +3,7 @@ module github.com/cohesion-api/cohesion_backend go 1.25.0 require ( + github.com/bradleyfalzon/ghinstallation/v2 v2.17.0 github.com/clerk/clerk-sdk-go/v2 v2.5.1 github.com/go-chi/chi/v5 v5.0.12 github.com/go-chi/cors v1.2.1 @@ -27,6 +28,8 @@ require ( github.com/go-jose/go-jose/v3 v3.0.4 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect + github.com/golang-jwt/jwt/v4 v4.5.2 // indirect + github.com/google/go-github/v75 v75.0.0 // indirect github.com/google/go-querystring v1.1.0 // indirect github.com/google/s2a-go v0.1.9 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.12 // indirect diff --git a/cohesion_backend/go.sum b/cohesion_backend/go.sum index 2483d0c..f45557c 100644 --- a/cohesion_backend/go.sum +++ b/cohesion_backend/go.sum @@ -10,6 +10,8 @@ cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdB cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10= cloud.google.com/go/longrunning v0.5.7 h1:WLbHekDbjK1fVFD3ibpFFVoyizlLRl73I7YKuAKilhU= cloud.google.com/go/longrunning v0.5.7/go.mod h1:8GClkudohy1Fxm3owmBGid8W0pSgodEMwEAztp38Xng= +github.com/bradleyfalzon/ghinstallation/v2 v2.17.0 h1:SmbUK/GxpAspRjSQbB6ARvH+ArzlNzTtHydNyXUQ6zg= +github.com/bradleyfalzon/ghinstallation/v2 v2.17.0/go.mod h1:vuD/xvJT9Y+ZVZRv4HQ42cMyPFIYqpc7AbB4Gvt/DlY= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/clerk/clerk-sdk-go/v2 v2.5.1 h1:RsakGNW6ie83b9KIRtKzqDXBJ//cURy9SJUbGhrsIKg= @@ -37,6 +39,8 @@ github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI= +github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/generative-ai-go v0.20.1 h1:6dEIujpgN2V0PgLhr6c/M1ynRdc7ARtiIDPFzj45uNQ= @@ -47,6 +51,8 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/go-github/v68 v68.0.0 h1:ZW57zeNZiXTdQ16qrDiZ0k6XucrxZ2CGmoTvcCyQG6s= github.com/google/go-github/v68 v68.0.0/go.mod h1:K9HAUBovM2sLwM408A18h+wd9vqdLOEqTUCbnRIcx68= +github.com/google/go-github/v75 v75.0.0 h1:k7q8Bvg+W5KxRl9Tjq16a9XEgVY1pwuiG5sIL7435Ic= +github.com/google/go-github/v75 v75.0.0/go.mod h1:H3LUJEA1TCrzuUqtdAQniBNwuKiQIqdGKgBo1/M/uqI= github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8= github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU= github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0= diff --git a/cohesion_backend/internal/config/config.go b/cohesion_backend/internal/config/config.go index 443d27b..d918f85 100644 --- a/cohesion_backend/internal/config/config.go +++ b/cohesion_backend/internal/config/config.go @@ -13,12 +13,25 @@ type Config struct { Environment string GeminiAPIKey string GeminiModel string + + GitHubAppID int64 + GitHubAppPrivateKey []byte + GitHubAppClientID string + GitHubAppClientSecret string + GitHubAppSlug string + FrontendURL string } func Load() *Config { godotenv.Load() port, _ := strconv.Atoi(getEnv("PORT", "8080")) + appID, _ := strconv.ParseInt(getEnv("GITHUB_APP_ID", "0"), 10, 64) + + var privateKey []byte + if path := getEnv("GITHUB_APP_PRIVATE_KEY_PATH", ""); path != "" { + privateKey, _ = os.ReadFile(path) + } return &Config{ DatabaseURL: getEnv("DATABASE_URL", ""), @@ -26,6 +39,13 @@ func Load() *Config { Environment: getEnv("ENVIRONMENT", "development"), GeminiAPIKey: getEnv("GEMINI_API_KEY", ""), GeminiModel: getEnv("GEMINI_MODEL", "gemini-2.0-flash"), + + GitHubAppID: appID, + GitHubAppPrivateKey: privateKey, + GitHubAppClientID: getEnv("GITHUB_APP_CLIENT_ID", ""), + GitHubAppClientSecret: getEnv("GITHUB_APP_CLIENT_SECRET", ""), + GitHubAppSlug: getEnv("GITHUB_APP_SLUG", ""), + FrontendURL: getEnv("FRONTEND_URL", "http://localhost:3000"), } } diff --git a/cohesion_backend/internal/controlplane/handlers/github_app.go b/cohesion_backend/internal/controlplane/handlers/github_app.go new file mode 100644 index 0000000..38091ff --- /dev/null +++ b/cohesion_backend/internal/controlplane/handlers/github_app.go @@ -0,0 +1,117 @@ +package handlers + +import ( + "encoding/json" + "fmt" + "log" + "net/http" + "strconv" + + "github.com/cohesion-api/cohesion_backend/internal/auth" + "github.com/go-chi/chi/v5" +) + +func (h *Handlers) GitHubAppStatus(w http.ResponseWriter, r *http.Request) { + configured := h.githubAppAuth.IsConfigured() + resp := map[string]interface{}{ + "configured": configured, + } + if configured && h.githubAppSlug != "" { + resp["install_url"] = fmt.Sprintf("https://github.com/apps/%s/installations/new", h.githubAppSlug) + } + respondJSON(w, http.StatusOK, resp) +} + +type SaveInstallationRequest struct { + InstallationID int64 `json:"installation_id"` +} + +func (h *Handlers) SaveGitHubInstallation(w http.ResponseWriter, r *http.Request) { + userID := auth.UserID(r.Context()) + if userID == "" { + respondError(w, http.StatusUnauthorized, "Not authenticated") + return + } + + var req SaveInstallationRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + respondError(w, http.StatusBadRequest, "Invalid request body") + return + } + + if req.InstallationID == 0 { + respondError(w, http.StatusBadRequest, "installation_id is required") + return + } + + if !h.githubAppAuth.IsConfigured() { + respondError(w, http.StatusBadRequest, "GitHub App is not configured") + return + } + + // Verify the installation exists via the GitHub API + appClient, err := h.githubAppAuth.AppClient() + if err != nil { + log.Printf("Failed to create GitHub App client: %v", err) + respondError(w, http.StatusInternalServerError, "Failed to verify installation") + return + } + + installation, _, err := appClient.Apps.GetInstallation(r.Context(), req.InstallationID) + if err != nil { + respondError(w, http.StatusBadRequest, "Invalid installation — make sure you completed the GitHub App install flow") + return + } + + accountLogin := installation.GetAccount().GetLogin() + accountType := installation.GetAccount().GetType() + + if err := h.ghInstallService.SaveInstallation(r.Context(), userID, req.InstallationID, accountLogin, accountType); err != nil { + respondError(w, http.StatusInternalServerError, "Failed to save installation") + return + } + + respondJSON(w, http.StatusCreated, map[string]interface{}{ + "message": "Installation saved", + "installation_id": req.InstallationID, + "github_account_login": accountLogin, + "github_account_type": accountType, + }) +} + +func (h *Handlers) ListGitHubInstallations(w http.ResponseWriter, r *http.Request) { + userID := auth.UserID(r.Context()) + if userID == "" { + respondError(w, http.StatusUnauthorized, "Not authenticated") + return + } + + installations, err := h.ghInstallService.List(r.Context(), userID) + if err != nil { + respondError(w, http.StatusInternalServerError, "Failed to list installations") + return + } + + respondJSON(w, http.StatusOK, installations) +} + +func (h *Handlers) RemoveGitHubInstallation(w http.ResponseWriter, r *http.Request) { + userID := auth.UserID(r.Context()) + if userID == "" { + respondError(w, http.StatusUnauthorized, "Not authenticated") + return + } + + installationID, err := strconv.ParseInt(chi.URLParam(r, "installationID"), 10, 64) + if err != nil { + respondError(w, http.StatusBadRequest, "Invalid installation ID") + return + } + + if err := h.ghInstallService.Remove(r.Context(), userID, installationID); err != nil { + respondError(w, http.StatusNotFound, "Installation not found") + return + } + + w.WriteHeader(http.StatusNoContent) +} diff --git a/cohesion_backend/internal/controlplane/handlers/handlers.go b/cohesion_backend/internal/controlplane/handlers/handlers.go index 9fccfec..c978208 100644 --- a/cohesion_backend/internal/controlplane/handlers/handlers.go +++ b/cohesion_backend/internal/controlplane/handlers/handlers.go @@ -3,6 +3,7 @@ package handlers import ( "encoding/json" "fmt" + "log" "net/http" "net/url" "strconv" @@ -11,20 +12,23 @@ import ( "github.com/cohesion-api/cohesion_backend/internal/auth" "github.com/cohesion-api/cohesion_backend/internal/models" + "github.com/cohesion-api/cohesion_backend/internal/repository" "github.com/cohesion-api/cohesion_backend/internal/services" "github.com/cohesion-api/cohesion_backend/pkg/analyzer" "github.com/cohesion-api/cohesion_backend/pkg/analyzer/gemini" - ghfetcher "github.com/cohesion-api/cohesion_backend/pkg/github" + ghpkg "github.com/cohesion-api/cohesion_backend/pkg/github" "github.com/cohesion-api/cohesion_backend/pkg/schemair" "github.com/go-chi/chi/v5" + gh "github.com/google/go-github/v68/github" "github.com/google/uuid" ) // ProxyTarget holds a configured proxy destination. type ProxyTarget struct { - Label string `json:"label"` - TargetURL *url.URL `json:"-"` - RawURL string `json:"target_url"` + Label string `json:"label"` + TargetURL *url.URL `json:"-"` + RawURL string `json:"target_url"` + ResolvedIP string `json:"-"` } type Handlers struct { @@ -34,7 +38,10 @@ type Handlers struct { diffService *services.DiffService liveService *services.LiveService userSettingsService *services.UserSettingsService + ghInstallService *services.GitHubInstallationService analyzer analyzer.Analyzer + githubAppAuth *ghpkg.AppAuth + githubAppSlug string proxyMu sync.RWMutex proxyTargets map[string]map[string]*ProxyTarget // projectID → label → target @@ -47,7 +54,10 @@ func New( diffService *services.DiffService, liveService *services.LiveService, userSettingsService *services.UserSettingsService, + ghInstallService *services.GitHubInstallationService, a analyzer.Analyzer, + githubAppAuth *ghpkg.AppAuth, + githubAppSlug string, ) *Handlers { return &Handlers{ projectService: projectService, @@ -56,7 +66,10 @@ func New( diffService: diffService, liveService: liveService, userSettingsService: userSettingsService, + ghInstallService: ghInstallService, analyzer: a, + githubAppAuth: githubAppAuth, + githubAppSlug: githubAppSlug, proxyTargets: make(map[string]map[string]*ProxyTarget), } } @@ -138,6 +151,10 @@ func (h *Handlers) DeleteProject(w http.ResponseWriter, r *http.Request) { userID := auth.UserID(r.Context()) if err := h.projectService.Delete(r.Context(), projectID, userID); err != nil { + if err == repository.ErrNotFound { + respondError(w, http.StatusNotFound, "Project not found") + return + } respondError(w, http.StatusInternalServerError, "Failed to delete project") return } @@ -274,15 +291,8 @@ func (h *Handlers) ScanCodebase(w http.ResponseWriter, r *http.Request) { schemas, err = ga.AnalyzeFiles(r.Context(), sourceFiles, language, mode) } else if req.DirPath != "" { - files, language := gemini.DiscoverFiles(req.DirPath) - if len(files) == 0 { - respondError(w, http.StatusBadRequest, "No source files found in "+req.DirPath) - return - } - if req.Language != "" { - language = req.Language - } - schemas, err = ga.AnalyzeFiles(r.Context(), files, language, mode) + respondError(w, http.StatusBadRequest, "dir_path is not supported for security reasons. Upload files directly using the 'files' field instead.") + return } else { respondError(w, http.StatusBadRequest, "Either 'files' or 'dir_path' must be provided") return @@ -335,7 +345,7 @@ func (h *Handlers) ScanGitHubRepo(w http.ResponseWriter, r *http.Request) { return } - owner, repo, err := ghfetcher.ParseRepoURL(req.RepoURL) + owner, repo, err := ghpkg.ParseRepoURL(req.RepoURL) if err != nil { respondError(w, http.StatusBadRequest, err.Error()) return @@ -367,12 +377,33 @@ func (h *Handlers) ScanGitHubRepo(w http.ResponseWriter, r *http.Request) { return } - if settings.GitHubToken == "" { - respondError(w, http.StatusBadRequest, "Add a GitHub token in Settings to scan repositories") + var ghClient *gh.Client + if h.githubAppAuth.IsConfigured() { + installations, err := h.ghInstallService.List(r.Context(), userID) + if err == nil { + for _, inst := range installations { + client, err := h.githubAppAuth.InstallationClient(inst.InstallationID) + if err != nil { + continue + } + + _, _, err = client.Repositories.Get(r.Context(), owner, repo) + if err == nil { + ghClient = client + break + } + } + } + } + if ghClient == nil && settings.GitHubToken != "" { + ghClient = gh.NewClient(nil).WithAuthToken(settings.GitHubToken) + } + if ghClient == nil { + respondError(w, http.StatusBadRequest, "Connect a GitHub App or add a Personal Access Token in Settings") return } - files, language, err := ghfetcher.FetchRepoFiles(r.Context(), settings.GitHubToken, owner, repo, req.Branch, req.Path) + files, language, err := ghpkg.FetchRepoFilesWithClient(r.Context(), ghClient, owner, repo, req.Branch, req.Path) if err != nil { respondError(w, http.StatusBadRequest, "GitHub fetch failed: "+err.Error()) return @@ -544,12 +575,18 @@ func (h *Handlers) SaveUserSettings(w http.ResponseWriter, r *http.Request) { return } - existing, _ := h.userSettingsService.Get(r.Context(), userID) - if strings.HasPrefix(req.GeminiAPIKey, "••") && existing.GeminiAPIKey != "" { - req.GeminiAPIKey = existing.GeminiAPIKey + existing, err := h.userSettingsService.Get(r.Context(), userID) + if err != nil { + respondError(w, http.StatusInternalServerError, "Failed to retrieve existing settings") + return } - if strings.HasPrefix(req.GitHubToken, "••") && existing.GitHubToken != "" { - req.GitHubToken = existing.GitHubToken + if existing != nil { + if strings.HasPrefix(req.GeminiAPIKey, "••") && existing.GeminiAPIKey != "" { + req.GeminiAPIKey = existing.GeminiAPIKey + } + if strings.HasPrefix(req.GitHubToken, "••") && existing.GitHubToken != "" { + req.GitHubToken = existing.GitHubToken + } } if err := h.userSettingsService.Save(r.Context(), userID, req.GeminiAPIKey, req.GeminiModel, req.GitHubToken); err != nil { @@ -594,13 +631,18 @@ func maskSecret(s string) string { if len(s) > 4 { return "••••••••" + s[len(s)-4:] } - return s + if len(s) > 0 { + return "••••••••" + } + return "" } func respondJSON(w http.ResponseWriter, status int, data interface{}) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(status) - json.NewEncoder(w).Encode(data) + if err := json.NewEncoder(w).Encode(data); err != nil { + log.Printf("ERROR: failed to encode JSON response: %v", err) + } } func respondError(w http.ResponseWriter, status int, message string) { diff --git a/cohesion_backend/internal/controlplane/handlers/live_handlers.go b/cohesion_backend/internal/controlplane/handlers/live_handlers.go index 5028af4..7999124 100644 --- a/cohesion_backend/internal/controlplane/handlers/live_handlers.go +++ b/cohesion_backend/internal/controlplane/handlers/live_handlers.go @@ -2,9 +2,11 @@ package handlers import ( "bytes" + "context" "encoding/json" "fmt" "io" + "net" "net/http" "net/http/httputil" "net/url" @@ -188,22 +190,20 @@ func (h *Handlers) StartCapture(w http.ResponseWriter, r *http.Request) { return } - h.liveService.StartCapture(projectID) + userID := auth.UserID(r.Context()) + h.liveService.StartCapture(projectID, userID) respondJSON(w, http.StatusOK, map[string]string{"message": "Self-capture started"}) } func (h *Handlers) StopCapture(w http.ResponseWriter, r *http.Request) { - active, projectID := h.liveService.IsCapturing() - if active { - userID := auth.UserID(r.Context()) - project, _ := h.projectService.GetByID(r.Context(), projectID, userID) - if project == nil { - respondError(w, http.StatusNotFound, "No active capture for your projects") - return - } + userID := auth.UserID(r.Context()) + active, projectID := h.liveService.IsCapturingForUser(userID) + if !active { + respondJSON(w, http.StatusOK, map[string]string{"message": "No active capture"}) + return } - h.liveService.StopCapture() + h.liveService.StopCapture(projectID) respondJSON(w, http.StatusOK, map[string]string{"message": "Self-capture stopped"}) } @@ -230,6 +230,52 @@ func (h *Handlers) ClearLiveBuffer(w http.ResponseWriter, r *http.Request) { respondJSON(w, http.StatusOK, map[string]string{"message": "Buffer cleared"}) } +func isPrivateIP(ip net.IP) bool { + privateRanges := []string{ + "10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", + "127.0.0.0/8", + "169.254.0.0/16", + "::1/128", "fc00::/7", "fe80::/10", + } + for _, cidr := range privateRanges { + _, network, _ := net.ParseCIDR(cidr) + if network.Contains(ip) { + return true + } + } + return false +} + +func validateProxyTarget(targetURL *url.URL) (string, error) { + host := targetURL.Hostname() + if host == "" { + return "", fmt.Errorf("empty host") + } + + if host == "localhost" || host == "0.0.0.0" || host == "[::1]" { + return "", fmt.Errorf("localhost targets are not allowed") + } + + if targetURL.Scheme != "http" && targetURL.Scheme != "https" { + return "", fmt.Errorf("only http and https schemes are allowed") + } + + ips, err := net.LookupIP(host) + if err != nil { + return "", fmt.Errorf("cannot resolve host: %w", err) + } + var resolvedIP string + for _, ip := range ips { + if isPrivateIP(ip) { + return "", fmt.Errorf("target resolves to a private/reserved IP address") + } + if resolvedIP == "" { + resolvedIP = ip.String() + } + } + return resolvedIP, nil +} + // ConfigureProxy sets up a proxy target for a given project and label. func (h *Handlers) ConfigureProxy(w http.ResponseWriter, r *http.Request) { var req struct { @@ -263,14 +309,21 @@ func (h *Handlers) ConfigureProxy(w http.ResponseWriter, r *http.Request) { return } + resolvedIP, err := validateProxyTarget(targetURL) + if err != nil { + respondError(w, http.StatusBadRequest, "Proxy target not allowed: "+err.Error()) + return + } + h.proxyMu.Lock() if h.proxyTargets[req.ProjectID] == nil { h.proxyTargets[req.ProjectID] = make(map[string]*ProxyTarget) } h.proxyTargets[req.ProjectID][req.Label] = &ProxyTarget{ - Label: req.Label, - TargetURL: targetURL, - RawURL: req.TargetURL, + Label: req.Label, + TargetURL: targetURL, + RawURL: req.TargetURL, + ResolvedIP: resolvedIP, } h.proxyMu.Unlock() @@ -330,7 +383,6 @@ func (h *Handlers) ProxyHandler(w http.ResponseWriter, r *http.Request) { start := time.Now() - // Set up reverse proxy proxy := &httputil.ReverseProxy{ Director: func(req *http.Request) { req.URL.Scheme = target.TargetURL.Scheme @@ -339,6 +391,26 @@ func (h *Handlers) ProxyHandler(w http.ResponseWriter, r *http.Request) { req.URL.RawQuery = r.URL.RawQuery req.Host = target.TargetURL.Host }, + Transport: &http.Transport{ + DialContext: (&net.Dialer{}).DialContext, + Proxy: http.ProxyFromEnvironment, + }, + } + if target.ResolvedIP != "" { + pinnedIP := target.ResolvedIP + port := target.TargetURL.Port() + if port == "" { + if target.TargetURL.Scheme == "https" { + port = "443" + } else { + port = "80" + } + } + proxy.Transport = &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return (&net.Dialer{}).DialContext(ctx, network, net.JoinHostPort(pinnedIP, port)) + }, + } } // Capture response @@ -459,6 +531,37 @@ func (h *Handlers) LiveDiff(w http.ResponseWriter, r *http.Request) { }) } +func (h *Handlers) GetLiveSchemas(w http.ResponseWriter, r *http.Request) { + projectIDStr := r.URL.Query().Get("project_id") + if projectIDStr == "" { + respondError(w, http.StatusBadRequest, "project_id query parameter is required") + return + } + + projectID, err := uuid.Parse(projectIDStr) + if err != nil { + respondError(w, http.StatusBadRequest, "Invalid project ID") + return + } + + if h.requireProjectAccess(w, r, projectID) == nil { + return + } + + source := r.URL.Query().Get("source") + if source == "" { + respondError(w, http.StatusBadRequest, "source query parameter is required") + return + } + + schemas := h.liveService.InferFromBufferBySource(projectID, source) + if schemas == nil { + schemas = []*schemair.SchemaIR{} + } + + respondJSON(w, http.StatusOK, schemas) +} + // GetLiveSources returns the distinct source labels in the buffer. func (h *Handlers) GetLiveSources(w http.ResponseWriter, r *http.Request) { projectIDStr := r.URL.Query().Get("project_id") diff --git a/cohesion_backend/internal/controlplane/router.go b/cohesion_backend/internal/controlplane/router.go index 2699da3..e90ef97 100644 --- a/cohesion_backend/internal/controlplane/router.go +++ b/cohesion_backend/internal/controlplane/router.go @@ -7,19 +7,24 @@ import ( "github.com/cohesion-api/cohesion_backend/internal/controlplane/handlers" "github.com/cohesion-api/cohesion_backend/internal/services" "github.com/cohesion-api/cohesion_backend/pkg/analyzer" + ghpkg "github.com/cohesion-api/cohesion_backend/pkg/github" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" "github.com/go-chi/cors" ) type Services struct { - ProjectService *services.ProjectService - EndpointService *services.EndpointService - SchemaService *services.SchemaService - DiffService *services.DiffService - LiveService *services.LiveService - UserSettingsService *services.UserSettingsService - Analyzer analyzer.Analyzer + ProjectService *services.ProjectService + EndpointService *services.EndpointService + SchemaService *services.SchemaService + DiffService *services.DiffService + LiveService *services.LiveService + UserSettingsService *services.UserSettingsService + GitHubInstallationService *services.GitHubInstallationService + Analyzer analyzer.Analyzer + GitHubAppAuth *ghpkg.AppAuth + GitHubAppSlug string + FrontendURL string } func NewRouter(svc *Services) http.Handler { @@ -37,17 +42,20 @@ func NewRouter(svc *Services) http.Handler { MaxAge: 300, })) - r.Use(svc.LiveService.SelfCaptureMiddleware()) - h := handlers.New( svc.ProjectService, svc.EndpointService, svc.SchemaService, - svc.DiffService, svc.LiveService, svc.UserSettingsService, svc.Analyzer, + svc.DiffService, svc.LiveService, svc.UserSettingsService, + svc.GitHubInstallationService, svc.Analyzer, + svc.GitHubAppAuth, svc.GitHubAppSlug, ) r.Route("/api", func(r chi.Router) { r.Get("/health", h.Health) r.Group(func(r chi.Router) { r.Use(auth.Middleware()) + r.Use(svc.LiveService.SelfCaptureMiddleware(func(r *http.Request) string { + return auth.UserID(r.Context()) + })) r.Route("/projects", func(r chi.Router) { r.Post("/", h.CreateProject) @@ -69,7 +77,7 @@ func NewRouter(svc *Services) http.Handler { r.Get("/{endpointID}", h.GetEndpoint) }) - r.Get("/diff/{endpointID}", h.ComputeDiff) + r.Post("/diff/{endpointID}", h.ComputeDiff) r.Get("/stats", h.GetStats) r.Route("/user", func(r chi.Router) { @@ -77,6 +85,13 @@ func NewRouter(svc *Services) http.Handler { r.Put("/settings", h.SaveUserSettings) }) + r.Route("/github", func(r chi.Router) { + r.Get("/status", h.GitHubAppStatus) + r.Post("/installations", h.SaveGitHubInstallation) + r.Get("/installations", h.ListGitHubInstallations) + r.Delete("/installations/{installationID}", h.RemoveGitHubInstallation) + }) + r.Route("/live", func(r chi.Router) { r.Post("/ingest", h.IngestRuntimeCaptures) r.Get("/stream", h.LiveStream) @@ -86,6 +101,7 @@ func NewRouter(svc *Services) http.Handler { r.Post("/capture/start", h.StartCapture) r.Post("/capture/stop", h.StopCapture) r.Post("/diff", h.LiveDiff) + r.Get("/schemas", h.GetLiveSchemas) r.Get("/sources", h.GetLiveSources) r.Post("/proxy/configure", h.ConfigureProxy) r.HandleFunc("/proxy/{projectID}/{label}/*", h.ProxyHandler) diff --git a/cohesion_backend/internal/crypto/crypto.go b/cohesion_backend/internal/crypto/crypto.go new file mode 100644 index 0000000..58869fd --- /dev/null +++ b/cohesion_backend/internal/crypto/crypto.go @@ -0,0 +1,101 @@ +package crypto + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "errors" + "io" + "log" + "os" + "sync" +) + +func deriveKey(secret string) []byte { + h := sha256.Sum256([]byte(secret)) + return h[:] +} + +var warnOnce sync.Once + +func getSecret() string { + s := os.Getenv("ENCRYPTION_KEY") + if s == "" { + warnOnce.Do(func() { + log.Println("WARNING: ENCRYPTION_KEY not set — secrets will be stored in plaintext") + }) + } + return s +} + +func Encrypt(plaintext string) (string, error) { + if plaintext == "" { + return "", nil + } + secret := getSecret() + if secret == "" { + return plaintext, nil + } + + block, err := aes.NewCipher(deriveKey(secret)) + if err != nil { + return "", err + } + + gcm, err := cipher.NewGCM(block) + if err != nil { + return "", err + } + + nonce := make([]byte, gcm.NonceSize()) + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + return "", err + } + + ciphertext := gcm.Seal(nonce, nonce, []byte(plaintext), nil) + return "enc:" + base64.StdEncoding.EncodeToString(ciphertext), nil +} + +func Decrypt(ciphertext string) (string, error) { + if ciphertext == "" { + return "", nil + } + if len(ciphertext) < 4 || ciphertext[:4] != "enc:" { + return ciphertext, nil + } + + secret := getSecret() + if secret == "" { + return "", errors.New("ENCRYPTION_KEY not set, cannot decrypt") + } + + data, err := base64.StdEncoding.DecodeString(ciphertext[4:]) + if err != nil { + return "", err + } + + block, err := aes.NewCipher(deriveKey(secret)) + if err != nil { + return "", err + } + + gcm, err := cipher.NewGCM(block) + if err != nil { + return "", err + } + + nonceSize := gcm.NonceSize() + if len(data) < nonceSize { + return "", errors.New("ciphertext too short") + } + + nonce, ciphertextBytes := data[:nonceSize], data[nonceSize:] + plaintext, err := gcm.Open(nil, nonce, ciphertextBytes, nil) + if err != nil { + return "", err + } + + return string(plaintext), nil +} diff --git a/cohesion_backend/internal/models/models.go b/cohesion_backend/internal/models/models.go index a373194..adea195 100644 --- a/cohesion_backend/internal/models/models.go +++ b/cohesion_backend/internal/models/models.go @@ -52,3 +52,13 @@ type UserSettings struct { CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` } + +type GitHubInstallation struct { + ID uuid.UUID `json:"id"` + ClerkUserID string `json:"clerk_user_id"` + InstallationID int64 `json:"installation_id"` + GitHubAccountLogin string `json:"github_account_login"` + GitHubAccountType string `json:"github_account_type"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} diff --git a/cohesion_backend/internal/repository/diff_repo.go b/cohesion_backend/internal/repository/diff_repo.go index 07a52f5..0f53d59 100644 --- a/cohesion_backend/internal/repository/diff_repo.go +++ b/cohesion_backend/internal/repository/diff_repo.go @@ -17,6 +17,8 @@ func NewDiffRepository(db *DB) *DiffRepository { return &DiffRepository{db: db} } +const maxDiffsPerEndpoint = 10 + func (r *DiffRepository) Create(ctx context.Context, diff *models.Diff) error { diff.ID = uuid.New() diff.CreatedAt = time.Now() @@ -25,8 +27,18 @@ func (r *DiffRepository) Create(ctx context.Context, diff *models.Diff) error { INSERT INTO diffs (id, endpoint_id, diff_data, sources_compared, created_at) VALUES ($1, $2, $3, $4, $5) `, diff.ID, diff.EndpointID, diff.DiffData, diff.SourcesCompared, diff.CreatedAt) + if err != nil { + return err + } + _, _ = r.db.Pool.Exec(ctx, ` + DELETE FROM diffs WHERE id IN ( + SELECT id FROM diffs WHERE endpoint_id = $1 + ORDER BY created_at DESC + OFFSET $2 + ) + `, diff.EndpointID, maxDiffsPerEndpoint) - return err + return nil } func (r *DiffRepository) GetByEndpointID(ctx context.Context, endpointID uuid.UUID) ([]models.Diff, error) { diff --git a/cohesion_backend/internal/repository/endpoint_repo.go b/cohesion_backend/internal/repository/endpoint_repo.go index 18ec70c..818f22e 100644 --- a/cohesion_backend/internal/repository/endpoint_repo.go +++ b/cohesion_backend/internal/repository/endpoint_repo.go @@ -21,18 +21,22 @@ func (r *EndpointRepository) Upsert(ctx context.Context, endpoint *models.Endpoi if endpoint.ID == uuid.Nil { endpoint.ID = uuid.New() } - endpoint.CreatedAt = time.Now() - endpoint.UpdatedAt = time.Now() + now := time.Now() + endpoint.UpdatedAt = now - _, err := r.db.Pool.Exec(ctx, ` + var returnedID uuid.UUID + err := r.db.Pool.QueryRow(ctx, ` INSERT INTO endpoints (id, project_id, path, method, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, $6) - ON CONFLICT (project_id, path, method) + ON CONFLICT (project_id, path, method) DO UPDATE SET updated_at = EXCLUDED.updated_at RETURNING id - `, endpoint.ID, endpoint.ProjectID, endpoint.Path, endpoint.Method, endpoint.CreatedAt, endpoint.UpdatedAt) - - return err + `, endpoint.ID, endpoint.ProjectID, endpoint.Path, endpoint.Method, now, now).Scan(&returnedID) + if err != nil { + return err + } + endpoint.ID = returnedID + return nil } func (r *EndpointRepository) GetByID(ctx context.Context, id uuid.UUID) (*models.Endpoint, error) { @@ -132,6 +136,108 @@ func (r *EndpointRepository) GetByProjectWithSchemas(ctx context.Context, projec return endpoints, nil } +func (r *EndpointRepository) GetByProjectIDsWithSchemas(ctx context.Context, projectIDs []uuid.UUID) ([]models.Endpoint, error) { + if len(projectIDs) == 0 { + return nil, nil + } + + rows, err := r.db.Pool.Query(ctx, ` + SELECT e.id, e.project_id, e.path, e.method, e.created_at, e.updated_at, + s.id, s.source, s.schema_data, s.version, s.created_at, s.updated_at + FROM endpoints e + LEFT JOIN schemas s ON s.endpoint_id = e.id + WHERE e.project_id = ANY($1) + ORDER BY e.path, e.method, s.source + `, projectIDs) + if err != nil { + return nil, err + } + defer rows.Close() + + endpointMap := make(map[uuid.UUID]*models.Endpoint) + var orderedIDs []uuid.UUID + + for rows.Next() { + var e models.Endpoint + var schemaID, schemaSource *string + var schemaData *map[string]interface{} + var schemaVersion *int + var schemaCreatedAt, schemaUpdatedAt *time.Time + + if err := rows.Scan( + &e.ID, &e.ProjectID, &e.Path, &e.Method, &e.CreatedAt, &e.UpdatedAt, + &schemaID, &schemaSource, &schemaData, &schemaVersion, &schemaCreatedAt, &schemaUpdatedAt, + ); err != nil { + return nil, err + } + + if _, exists := endpointMap[e.ID]; !exists { + e.Schemas = []models.Schema{} + endpointMap[e.ID] = &e + orderedIDs = append(orderedIDs, e.ID) + } + + if schemaID != nil { + schema := models.Schema{ + ID: uuid.MustParse(*schemaID), + EndpointID: e.ID, + Source: *schemaSource, + SchemaData: *schemaData, + Version: *schemaVersion, + CreatedAt: *schemaCreatedAt, + UpdatedAt: *schemaUpdatedAt, + } + endpointMap[e.ID].Schemas = append(endpointMap[e.ID].Schemas, schema) + } + } + + if err := rows.Err(); err != nil { + return nil, err + } + + endpoints := make([]models.Endpoint, 0, len(orderedIDs)) + for _, id := range orderedIDs { + endpoints = append(endpoints, *endpointMap[id]) + } + + return endpoints, nil +} + +func (r *EndpointRepository) UpsertBatch(ctx context.Context, endpoints []*models.Endpoint) error { + if len(endpoints) == 0 { + return nil + } + + now := time.Now() + batch := &pgx.Batch{} + + for _, ep := range endpoints { + if ep.ID == uuid.Nil { + ep.ID = uuid.New() + } + ep.UpdatedAt = now + batch.Queue(` + INSERT INTO endpoints (id, project_id, path, method, created_at, updated_at) + VALUES ($1, $2, $3, $4, $5, $6) + ON CONFLICT (project_id, path, method) + DO UPDATE SET updated_at = EXCLUDED.updated_at + RETURNING id + `, ep.ID, ep.ProjectID, ep.Path, ep.Method, now, now) + } + + br := r.db.Pool.SendBatch(ctx, batch) + defer br.Close() + + for _, ep := range endpoints { + var returnedID uuid.UUID + if err := br.QueryRow().Scan(&returnedID); err != nil { + return err + } + ep.ID = returnedID + } + return nil +} + func (r *EndpointRepository) GetByPathAndMethod(ctx context.Context, projectID uuid.UUID, path, method string) (*models.Endpoint, error) { var endpoint models.Endpoint err := r.db.Pool.QueryRow(ctx, ` diff --git a/cohesion_backend/internal/repository/github_installation_repo.go b/cohesion_backend/internal/repository/github_installation_repo.go new file mode 100644 index 0000000..605e990 --- /dev/null +++ b/cohesion_backend/internal/repository/github_installation_repo.go @@ -0,0 +1,73 @@ +package repository + +import ( + "context" + "time" + + "github.com/cohesion-api/cohesion_backend/internal/models" + "github.com/google/uuid" + "github.com/jackc/pgx/v5" +) + +type GitHubInstallationRepository struct { + db *DB +} + +func NewGitHubInstallationRepository(db *DB) *GitHubInstallationRepository { + return &GitHubInstallationRepository{db: db} +} + +func (r *GitHubInstallationRepository) ListByClerkUserID(ctx context.Context, clerkUserID string) ([]models.GitHubInstallation, error) { + rows, err := r.db.Pool.Query(ctx, ` + SELECT id, clerk_user_id, installation_id, github_account_login, github_account_type, created_at, updated_at + FROM github_installations WHERE clerk_user_id = $1 ORDER BY created_at + `, clerkUserID) + if err != nil { + return nil, err + } + defer rows.Close() + + var installations []models.GitHubInstallation + for rows.Next() { + var i models.GitHubInstallation + if err := rows.Scan(&i.ID, &i.ClerkUserID, &i.InstallationID, &i.GitHubAccountLogin, &i.GitHubAccountType, &i.CreatedAt, &i.UpdatedAt); err != nil { + return nil, err + } + installations = append(installations, i) + } + return installations, rows.Err() +} + +func (r *GitHubInstallationRepository) Upsert(ctx context.Context, inst *models.GitHubInstallation) error { + now := time.Now() + inst.UpdatedAt = now + + if inst.ID == uuid.Nil { + inst.ID = uuid.New() + inst.CreatedAt = now + } + + _, err := r.db.Pool.Exec(ctx, ` + INSERT INTO github_installations (id, clerk_user_id, installation_id, github_account_login, github_account_type, created_at, updated_at) + VALUES ($1, $2, $3, $4, $5, $6, $7) + ON CONFLICT (clerk_user_id, installation_id) DO UPDATE SET + github_account_login = EXCLUDED.github_account_login, + github_account_type = EXCLUDED.github_account_type, + updated_at = EXCLUDED.updated_at + `, inst.ID, inst.ClerkUserID, inst.InstallationID, inst.GitHubAccountLogin, inst.GitHubAccountType, inst.CreatedAt, inst.UpdatedAt) + + return err +} + +func (r *GitHubInstallationRepository) DeleteByInstallationID(ctx context.Context, clerkUserID string, installationID int64) error { + result, err := r.db.Pool.Exec(ctx, ` + DELETE FROM github_installations WHERE clerk_user_id = $1 AND installation_id = $2 + `, clerkUserID, installationID) + if err != nil { + return err + } + if result.RowsAffected() == 0 { + return pgx.ErrNoRows + } + return nil +} diff --git a/cohesion_backend/internal/repository/project_repo.go b/cohesion_backend/internal/repository/project_repo.go index e230308..0052dfd 100644 --- a/cohesion_backend/internal/repository/project_repo.go +++ b/cohesion_backend/internal/repository/project_repo.go @@ -2,6 +2,7 @@ package repository import ( "context" + "fmt" "time" "github.com/cohesion-api/cohesion_backend/internal/models" @@ -64,7 +65,15 @@ func (r *ProjectRepository) List(ctx context.Context, ownerID string) ([]models. return projects, rows.Err() } +var ErrNotFound = fmt.Errorf("not found") + func (r *ProjectRepository) Delete(ctx context.Context, id uuid.UUID, ownerID string) error { - _, err := r.db.Pool.Exec(ctx, `DELETE FROM projects WHERE id = $1 AND owner_id = $2`, id, ownerID) - return err + result, err := r.db.Pool.Exec(ctx, `DELETE FROM projects WHERE id = $1 AND owner_id = $2`, id, ownerID) + if err != nil { + return err + } + if result.RowsAffected() == 0 { + return ErrNotFound + } + return nil } diff --git a/cohesion_backend/internal/repository/user_settings_repo.go b/cohesion_backend/internal/repository/user_settings_repo.go index 7bc6db3..03cedea 100644 --- a/cohesion_backend/internal/repository/user_settings_repo.go +++ b/cohesion_backend/internal/repository/user_settings_repo.go @@ -4,6 +4,7 @@ import ( "context" "time" + "github.com/cohesion-api/cohesion_backend/internal/crypto" "github.com/cohesion-api/cohesion_backend/internal/models" "github.com/google/uuid" "github.com/jackc/pgx/v5" @@ -27,7 +28,18 @@ func (r *UserSettingsRepository) GetByClerkUserID(ctx context.Context, clerkUser if err == pgx.ErrNoRows { return nil, nil } - return &s, err + if err != nil { + return nil, err + } + + if decrypted, err := crypto.Decrypt(s.GeminiAPIKey); err == nil { + s.GeminiAPIKey = decrypted + } + if decrypted, err := crypto.Decrypt(s.GitHubToken); err == nil { + s.GitHubToken = decrypted + } + + return &s, nil } func (r *UserSettingsRepository) Upsert(ctx context.Context, settings *models.UserSettings) error { @@ -39,7 +51,16 @@ func (r *UserSettingsRepository) Upsert(ctx context.Context, settings *models.Us settings.CreatedAt = now } - _, err := r.db.Pool.Exec(ctx, ` + encGemini, err := crypto.Encrypt(settings.GeminiAPIKey) + if err != nil { + return err + } + encGitHub, err := crypto.Encrypt(settings.GitHubToken) + if err != nil { + return err + } + + _, err = r.db.Pool.Exec(ctx, ` INSERT INTO user_settings (id, clerk_user_id, gemini_api_key, gemini_model, github_token, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, $6, $7) ON CONFLICT (clerk_user_id) DO UPDATE SET @@ -47,7 +68,7 @@ func (r *UserSettingsRepository) Upsert(ctx context.Context, settings *models.Us gemini_model = EXCLUDED.gemini_model, github_token = EXCLUDED.github_token, updated_at = EXCLUDED.updated_at - `, settings.ID, settings.ClerkUserID, settings.GeminiAPIKey, settings.GeminiModel, settings.GitHubToken, settings.CreatedAt, settings.UpdatedAt) + `, settings.ID, settings.ClerkUserID, encGemini, settings.GeminiModel, encGitHub, settings.CreatedAt, settings.UpdatedAt) return err } diff --git a/cohesion_backend/internal/services/diff_service.go b/cohesion_backend/internal/services/diff_service.go index a730902..6fdb8d2 100644 --- a/cohesion_backend/internal/services/diff_service.go +++ b/cohesion_backend/internal/services/diff_service.go @@ -3,6 +3,8 @@ package services import ( "context" "encoding/json" + "fmt" + "log" "github.com/cohesion-api/cohesion_backend/internal/models" "github.com/cohesion-api/cohesion_backend/internal/repository" @@ -27,7 +29,7 @@ func NewDiffService(diffRepo *repository.DiffRepository, schemaRepo *repository. } } -func schemasToIR(schemas []models.Schema) []schemair.SchemaIR { +func schemasToIR(schemas []models.Schema) ([]schemair.SchemaIR, []string) { seen := make(map[string]bool) var deduped []models.Schema for _, schema := range schemas { @@ -38,14 +40,19 @@ func schemasToIR(schemas []models.Schema) []schemair.SchemaIR { deduped = append(deduped, schema) } + var warnings []string result := make([]schemair.SchemaIR, 0, len(deduped)) for _, schema := range deduped { var ir schemair.SchemaIR data, err := json.Marshal(schema.SchemaData) if err != nil { + log.Printf("WARNING: failed to marshal schema %s (source=%s): %v", schema.ID, schema.Source, err) + warnings = append(warnings, fmt.Sprintf("Skipped corrupt schema from %s", schema.Source)) continue } if err := json.Unmarshal(data, &ir); err != nil { + log.Printf("WARNING: failed to unmarshal schema %s (source=%s): %v", schema.ID, schema.Source, err) + warnings = append(warnings, fmt.Sprintf("Skipped unparseable schema from %s", schema.Source)) continue } if ir.Source == "" { @@ -53,7 +60,7 @@ func schemasToIR(schemas []models.Schema) []schemair.SchemaIR { } result = append(result, ir) } - return result + return result, warnings } func (s *DiffService) ComputeDiff(ctx context.Context, endpointID uuid.UUID) (*diff.Result, error) { @@ -67,34 +74,28 @@ func (s *DiffService) ComputeDiff(ctx context.Context, endpointID uuid.UUID) (*d return nil, err } - if len(schemas) < 2 { - return &diff.Result{ - Endpoint: endpoint.Path, - Method: endpoint.Method, - Status: schemair.StatusMatch, - Mismatches: []diff.Mismatch{}, - }, nil - } - - schemaIRs := schemasToIR(schemas) + // Always go through the engine so we get a proper Confidence object + schemaIRs, _ := schemasToIR(schemas) result := s.diffEngine.Compare(endpoint.Path, endpoint.Method, schemaIRs) - diffData := map[string]interface{}{ - "endpoint": result.Endpoint, - "method": result.Method, - "status": result.Status, - "mismatches": result.Mismatches, - "sources_compared": result.SourcesCompared, - } + if len(schemaIRs) >= 2 { + diffData := map[string]interface{}{ + "endpoint": result.Endpoint, + "method": result.Method, + "status": result.Status, + "mismatches": result.Mismatches, + "sources_compared": result.SourcesCompared, + } - diffModel := &models.Diff{ - EndpointID: endpointID, - DiffData: diffData, - SourcesCompared: formatSources(result.SourcesCompared), - } + diffModel := &models.Diff{ + EndpointID: endpointID, + DiffData: diffData, + SourcesCompared: formatSources(result.SourcesCompared), + } - if err := s.diffRepo.Create(ctx, diffModel); err != nil { - return nil, err + if err := s.diffRepo.Create(ctx, diffModel); err != nil { + return nil, err + } } return result, nil @@ -109,35 +110,37 @@ type DiffStats struct { func (s *DiffService) ComputeStats(ctx context.Context, projectIDs []uuid.UUID) (*DiffStats, error) { stats := &DiffStats{} - for _, projectID := range projectIDs { - endpoints, err := s.endpointRepo.GetByProjectWithSchemas(ctx, projectID) - if err != nil { - return nil, err + if len(projectIDs) == 0 { + return stats, nil + } + + endpoints, err := s.endpointRepo.GetByProjectIDsWithSchemas(ctx, projectIDs) + if err != nil { + return nil, err + } + + for _, endpoint := range endpoints { + sources := make(map[string]bool) + for _, schema := range endpoint.Schemas { + sources[schema.Source] = true + } + if len(sources) < 2 { + continue + } + + schemaIRs, _ := schemasToIR(endpoint.Schemas) + if len(schemaIRs) < 2 { + continue } - for _, endpoint := range endpoints { - sources := make(map[string]bool) - for _, schema := range endpoint.Schemas { - sources[schema.Source] = true - } - if len(sources) < 2 { - continue - } - - schemaIRs := schemasToIR(endpoint.Schemas) - if len(schemaIRs) < 2 { - continue - } - - result := s.diffEngine.Compare(endpoint.Path, endpoint.Method, schemaIRs) - switch result.Status { - case schemair.StatusMatch: - stats.Matched++ - case schemair.StatusPartial: - stats.Partial++ - case schemair.StatusViolation: - stats.Violations++ - } + result := s.diffEngine.Compare(endpoint.Path, endpoint.Method, schemaIRs) + switch result.Status { + case schemair.StatusMatch: + stats.Matched++ + case schemair.StatusPartial: + stats.Partial++ + case schemair.StatusViolation: + stats.Violations++ } } return stats, nil diff --git a/cohesion_backend/internal/services/endpoint_service.go b/cohesion_backend/internal/services/endpoint_service.go index 5b9e62d..a110d59 100644 --- a/cohesion_backend/internal/services/endpoint_service.go +++ b/cohesion_backend/internal/services/endpoint_service.go @@ -40,23 +40,13 @@ func (s *EndpointService) ListByProject(ctx context.Context, projectID uuid.UUID } func (s *EndpointService) GetOrCreate(ctx context.Context, projectID uuid.UUID, path, method string) (*models.Endpoint, error) { - existing, err := s.endpointRepo.GetByPathAndMethod(ctx, projectID, path, method) - if err != nil { - return nil, err - } - if existing != nil { - return existing, nil - } - endpoint := &models.Endpoint{ ProjectID: projectID, Path: path, Method: method, } - if err := s.endpointRepo.Upsert(ctx, endpoint); err != nil { return nil, err } - return endpoint, nil } diff --git a/cohesion_backend/internal/services/github_installation_service.go b/cohesion_backend/internal/services/github_installation_service.go new file mode 100644 index 0000000..624bd76 --- /dev/null +++ b/cohesion_backend/internal/services/github_installation_service.go @@ -0,0 +1,41 @@ +package services + +import ( + "context" + + "github.com/cohesion-api/cohesion_backend/internal/models" + "github.com/cohesion-api/cohesion_backend/internal/repository" +) + +type GitHubInstallationService struct { + repo *repository.GitHubInstallationRepository +} + +func NewGitHubInstallationService(repo *repository.GitHubInstallationRepository) *GitHubInstallationService { + return &GitHubInstallationService{repo: repo} +} + +func (s *GitHubInstallationService) List(ctx context.Context, clerkUserID string) ([]models.GitHubInstallation, error) { + installations, err := s.repo.ListByClerkUserID(ctx, clerkUserID) + if err != nil { + return nil, err + } + if installations == nil { + return []models.GitHubInstallation{}, nil + } + return installations, nil +} + +func (s *GitHubInstallationService) SaveInstallation(ctx context.Context, clerkUserID string, installationID int64, accountLogin, accountType string) error { + inst := &models.GitHubInstallation{ + ClerkUserID: clerkUserID, + InstallationID: installationID, + GitHubAccountLogin: accountLogin, + GitHubAccountType: accountType, + } + return s.repo.Upsert(ctx, inst) +} + +func (s *GitHubInstallationService) Remove(ctx context.Context, clerkUserID string, installationID int64) error { + return s.repo.DeleteByInstallationID(ctx, clerkUserID, installationID) +} diff --git a/cohesion_backend/internal/services/live_service.go b/cohesion_backend/internal/services/live_service.go index e90d8ec..e37b2fd 100644 --- a/cohesion_backend/internal/services/live_service.go +++ b/cohesion_backend/internal/services/live_service.go @@ -33,49 +33,93 @@ type LiveEvent struct { } type projectBuffer struct { - requests []LiveRequest - maxSize int + data []LiveRequest + maxSize int + head int + count int } func (b *projectBuffer) add(req LiveRequest) { - if len(b.requests) >= b.maxSize { - b.requests = b.requests[1:] + if b.count < b.maxSize { + b.data = append(b.data, req) + b.count++ + b.head = b.count % b.maxSize + } else { + b.data[b.head] = req + b.head = (b.head + 1) % b.maxSize } - b.requests = append(b.requests, req) +} + +func (b *projectBuffer) all() []LiveRequest { + if b.count < b.maxSize { + result := make([]LiveRequest, b.count) + copy(result, b.data[:b.count]) + return result + } + result := make([]LiveRequest, b.maxSize) + copy(result, b.data[b.head:]) + copy(result[b.maxSize-b.head:], b.data[:b.head]) + return result +} + +func (b *projectBuffer) clear() { + b.data = b.data[:0] + b.head = 0 + b.count = 0 +} + +type captureEntry struct { + ownerID string } type LiveService struct { - mu sync.RWMutex - buffers map[uuid.UUID]*projectBuffer - subscribers map[uuid.UUID]map[chan LiveEvent]struct{} - maxPerProj int - captureProjectID uuid.UUID + mu sync.RWMutex + buffers map[uuid.UUID]*projectBuffer + subscribers map[uuid.UUID]map[chan LiveEvent]struct{} + maxPerProj int + captures map[uuid.UUID]captureEntry } func NewLiveService() *LiveService { return &LiveService{ buffers: make(map[uuid.UUID]*projectBuffer), subscribers: make(map[uuid.UUID]map[chan LiveEvent]struct{}), + captures: make(map[uuid.UUID]captureEntry), maxPerProj: 200, } } -func (s *LiveService) StartCapture(projectID uuid.UUID) { +func (s *LiveService) StartCapture(projectID uuid.UUID, ownerID string) { s.mu.Lock() defer s.mu.Unlock() - s.captureProjectID = projectID + s.captures[projectID] = captureEntry{ownerID: ownerID} } -func (s *LiveService) StopCapture() { +func (s *LiveService) StopCapture(projectID uuid.UUID) { s.mu.Lock() defer s.mu.Unlock() - s.captureProjectID = uuid.Nil + delete(s.captures, projectID) } -func (s *LiveService) IsCapturing() (bool, uuid.UUID) { +func (s *LiveService) IsCapturingForUser(userID string) (bool, uuid.UUID) { s.mu.RLock() defer s.mu.RUnlock() - return s.captureProjectID != uuid.Nil, s.captureProjectID + for projectID, entry := range s.captures { + if entry.ownerID == userID { + return true, projectID + } + } + return false, uuid.Nil +} + +func (s *LiveService) IsProjectCapturing(projectID uuid.UUID) (bool, string) { + s.mu.RLock() + defer s.mu.RUnlock() + entry, ok := s.captures[projectID] + if !ok { + return false, "" + } + return true, entry.ownerID } type captureResponseWriter struct { @@ -94,10 +138,16 @@ func (w *captureResponseWriter) Write(b []byte) (int, error) { return w.ResponseWriter.Write(b) } -func (s *LiveService) SelfCaptureMiddleware() func(http.Handler) http.Handler { +func (s *LiveService) SelfCaptureMiddleware(userIDFunc func(r *http.Request) string) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - active, projectID := s.IsCapturing() + userID := userIDFunc(r) + if userID == "" { + next.ServeHTTP(w, r) + return + } + + active, projectID := s.IsCapturingForUser(userID) if !active { next.ServeHTTP(w, r) return @@ -147,20 +197,19 @@ func (s *LiveService) SelfCaptureMiddleware() func(http.Handler) http.Handler { } } -func (s *LiveService) IngestRequests(projectID uuid.UUID, requests []LiveRequest) []runtime.CapturedRequest { +func (s *LiveService) IngestRequests(projectID uuid.UUID, requests []LiveRequest) { s.mu.Lock() defer s.mu.Unlock() buf, ok := s.buffers[projectID] if !ok { buf = &projectBuffer{ - requests: make([]LiveRequest, 0, s.maxPerProj), - maxSize: s.maxPerProj, + data: make([]LiveRequest, 0, s.maxPerProj), + maxSize: s.maxPerProj, } s.buffers[projectID] = buf } - var captured []runtime.CapturedRequest for i := range requests { req := &requests[i] if req.ID == "" { @@ -171,23 +220,12 @@ func (s *LiveService) IngestRequests(projectID uuid.UUID, requests []LiveRequest } buf.add(*req) - captured = append(captured, runtime.CapturedRequest{ - Path: req.Path, - Method: req.Method, - RequestBody: req.RequestBody, - StatusCode: req.StatusCode, - Response: req.ResponseBody, - ObservationCount: 1, - }) - s.broadcast(projectID, LiveEvent{ Type: "request", Payload: *req, Source: req.Source, }) } - - return captured } func (s *LiveService) GetRecentRequests(projectID uuid.UUID) []LiveRequest { @@ -199,9 +237,7 @@ func (s *LiveService) GetRecentRequests(projectID uuid.UUID) []LiveRequest { return []LiveRequest{} } - result := make([]LiveRequest, len(buf.requests)) - copy(result, buf.requests) - return result + return buf.all() } func (s *LiveService) GetBufferedAsCaptured(projectID uuid.UUID) []runtime.CapturedRequest { @@ -213,16 +249,17 @@ func (s *LiveService) GetBufferedAsCaptured(projectID uuid.UUID) []runtime.Captu return nil } - var result []runtime.CapturedRequest - for _, req := range buf.requests { - result = append(result, runtime.CapturedRequest{ + requests := buf.all() + result := make([]runtime.CapturedRequest, len(requests)) + for i, req := range requests { + result[i] = runtime.CapturedRequest{ Path: req.Path, Method: req.Method, RequestBody: req.RequestBody, StatusCode: req.StatusCode, Response: req.ResponseBody, ObservationCount: 1, - }) + } } return result } @@ -245,7 +282,7 @@ func (s *LiveService) GetBufferedBySource(projectID uuid.UUID, source string) [] } var result []LiveRequest - for _, req := range buf.requests { + for _, req := range buf.all() { if req.Source == source { result = append(result, req) } @@ -263,7 +300,7 @@ func (s *LiveService) GetBufferedAsCapturedBySource(projectID uuid.UUID, source } var result []runtime.CapturedRequest - for _, req := range buf.requests { + for _, req := range buf.all() { if req.Source == source { result = append(result, runtime.CapturedRequest{ Path: req.Path, @@ -296,7 +333,7 @@ func (s *LiveService) GetDistinctSources(projectID uuid.UUID) []string { } seen := make(map[string]struct{}) - for _, req := range buf.requests { + for _, req := range buf.all() { if req.Source != "" { seen[req.Source] = struct{}{} } @@ -313,9 +350,7 @@ func (s *LiveService) ClearBuffer(projectID uuid.UUID) { s.mu.Lock() defer s.mu.Unlock() - if buf, ok := s.buffers[projectID]; ok { - buf.requests = buf.requests[:0] - } + delete(s.buffers, projectID) s.broadcast(projectID, LiveEvent{Type: "clear"}) } @@ -347,13 +382,6 @@ func (s *LiveService) broadcast(projectID uuid.UUID, event LiveEvent) { if !ok { return } - data, err := json.Marshal(event) - if err != nil { - return - } - var evt LiveEvent - json.Unmarshal(data, &evt) - for ch := range subs { select { case ch <- event: diff --git a/cohesion_backend/internal/services/schema_service.go b/cohesion_backend/internal/services/schema_service.go index 659c621..47b03a8 100644 --- a/cohesion_backend/internal/services/schema_service.go +++ b/cohesion_backend/internal/services/schema_service.go @@ -30,14 +30,37 @@ type SchemaUploadRequest struct { } func (s *SchemaService) UploadSchemas(ctx context.Context, projectID uuid.UUID, schemas []schemair.SchemaIR) error { - var dbSchemas []models.Schema + if len(schemas) == 0 { + return nil + } + + type epKey struct{ path, method string } + epMap := make(map[epKey]*models.Endpoint) + var uniqueEndpoints []*models.Endpoint for _, schema := range schemas { normalizedPath := s.normalizePath(schema.Endpoint) - endpoint, err := s.getOrCreateEndpoint(ctx, projectID, normalizedPath, schema.Method) - if err != nil { - return err + key := epKey{normalizedPath, schema.Method} + if _, exists := epMap[key]; !exists { + ep := &models.Endpoint{ + ProjectID: projectID, + Path: normalizedPath, + Method: schema.Method, + } + epMap[key] = ep + uniqueEndpoints = append(uniqueEndpoints, ep) } + } + + if err := s.endpointRepo.UpsertBatch(ctx, uniqueEndpoints); err != nil { + return err + } + + dbSchemas := make([]models.Schema, 0, len(schemas)) + for _, schema := range schemas { + normalizedPath := s.normalizePath(schema.Endpoint) + key := epKey{normalizedPath, schema.Method} + endpoint := epMap[key] schemaData := map[string]interface{}{ "endpoint": normalizedPath, @@ -54,11 +77,7 @@ func (s *SchemaService) UploadSchemas(ctx context.Context, projectID uuid.UUID, }) } - if len(dbSchemas) > 0 { - return s.schemaRepo.UpsertBatch(ctx, dbSchemas) - } - - return nil + return s.schemaRepo.UpsertBatch(ctx, dbSchemas) } func (s *SchemaService) normalizePath(path string) string { @@ -75,28 +94,6 @@ func (s *SchemaService) normalizePath(path string) string { return path } -func (s *SchemaService) getOrCreateEndpoint(ctx context.Context, projectID uuid.UUID, path, method string) (*models.Endpoint, error) { - existing, err := s.endpointRepo.GetByPathAndMethod(ctx, projectID, path, method) - if err != nil { - return nil, err - } - if existing != nil { - return existing, nil - } - - endpoint := &models.Endpoint{ - ProjectID: projectID, - Path: path, - Method: method, - } - - if err := s.endpointRepo.Upsert(ctx, endpoint); err != nil { - return nil, err - } - - return s.endpointRepo.GetByPathAndMethod(ctx, projectID, path, method) -} - func (s *SchemaService) GetByEndpoint(ctx context.Context, endpointID uuid.UUID) ([]models.Schema, error) { return s.schemaRepo.GetByEndpointID(ctx, endpointID) } diff --git a/cohesion_backend/migrations/005_github_app_installations.down.sql b/cohesion_backend/migrations/005_github_app_installations.down.sql new file mode 100644 index 0000000..156781a --- /dev/null +++ b/cohesion_backend/migrations/005_github_app_installations.down.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS github_installations; diff --git a/cohesion_backend/migrations/005_github_app_installations.up.sql b/cohesion_backend/migrations/005_github_app_installations.up.sql new file mode 100644 index 0000000..32e166e --- /dev/null +++ b/cohesion_backend/migrations/005_github_app_installations.up.sql @@ -0,0 +1,11 @@ +CREATE TABLE github_installations ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + clerk_user_id VARCHAR(255) NOT NULL, + installation_id BIGINT NOT NULL, + github_account_login VARCHAR(255) NOT NULL DEFAULT '', + github_account_type VARCHAR(50) NOT NULL DEFAULT '', + created_at TIMESTAMPTZ DEFAULT NOW(), + updated_at TIMESTAMPTZ DEFAULT NOW(), + UNIQUE(clerk_user_id, installation_id) +); +CREATE INDEX idx_github_installations_clerk_user_id ON github_installations(clerk_user_id); diff --git a/cohesion_backend/pkg/analyzer/gemini/filediscovery.go b/cohesion_backend/pkg/analyzer/gemini/filediscovery.go index 6178834..76789af 100644 --- a/cohesion_backend/pkg/analyzer/gemini/filediscovery.go +++ b/cohesion_backend/pkg/analyzer/gemini/filediscovery.go @@ -7,6 +7,8 @@ import ( "strings" ignore "github.com/sabhiram/go-gitignore" + + "github.com/cohesion-api/cohesion_backend/pkg/sourcefile" ) type SourceFile struct { @@ -14,56 +16,6 @@ type SourceFile struct { Content string } -var skipDirs = map[string]bool{ - "vendor": true, - ".git": true, - "node_modules": true, - "__pycache__": true, - ".venv": true, - "venv": true, - "dist": true, - "build": true, - "target": true, - ".idea": true, - ".vscode": true, - ".next": true, - ".nuxt": true, -} - -var sourceExtensions = map[string]bool{ - ".go": true, - ".py": true, - ".ts": true, - ".js": true, - ".java": true, - ".rb": true, - ".rs": true, - ".php": true, - ".cs": true, - ".kt": true, - ".ex": true, - ".exs": true, - ".scala": true, - ".swift": true, -} - -var languageHints = map[string]string{ - ".go": "Go", - ".py": "Python", - ".ts": "TypeScript", - ".js": "JavaScript", - ".java": "Java", - ".rb": "Ruby", - ".rs": "Rust", - ".php": "PHP", - ".cs": "C#", - ".kt": "Kotlin", - ".ex": "Elixir", - ".exs": "Elixir", - ".scala": "Scala", - ".swift": "Swift", -} - const maxTokenBudget = 900_000 const bytesPerToken = 4 @@ -94,20 +46,6 @@ func filePriority(name string) int { return 4 } -func isTestFile(name string) bool { - lower := strings.ToLower(name) - return strings.HasSuffix(lower, "_test.go") || - strings.HasPrefix(lower, "test_") || - strings.HasSuffix(lower, "_test.py") || - strings.HasSuffix(lower, ".test.ts") || - strings.HasSuffix(lower, ".test.js") || - strings.HasSuffix(lower, ".spec.ts") || - strings.HasSuffix(lower, ".spec.js") || - strings.Contains(lower, "/test/") || - strings.Contains(lower, "/tests/") || - strings.Contains(lower, "/__tests__/") -} - type fileEntry struct { path string relPath string @@ -167,7 +105,7 @@ func loadGitIgnoreMatcher(rootPath string) *ignore.GitIgnore { } if info.IsDir() { - if skipDirs[info.Name()] { + if sourcefile.SkipDirs[info.Name()] { return filepath.SkipDir } return nil @@ -206,7 +144,7 @@ func DetectLanguage(files []SourceFile) string { extCount := make(map[string]int) for _, f := range files { ext := strings.ToLower(filepath.Ext(f.Path)) - if languageHints[ext] != "" { + if sourcefile.LanguageHints[ext] != "" { extCount[ext]++ } } @@ -216,7 +154,7 @@ func DetectLanguage(files []SourceFile) string { for ext, count := range extCount { if count > maxCount { maxCount = count - language = languageHints[ext] + language = sourcefile.LanguageHints[ext] } } return language @@ -239,7 +177,7 @@ func DiscoverFiles(rootPath string) ([]SourceFile, string) { relPath = filepath.ToSlash(relPath) if info.IsDir() { - if skipDirs[info.Name()] { + if sourcefile.SkipDirs[info.Name()] { return filepath.SkipDir } if relPath != "." && gitIgnoreMatcher != nil && gitIgnoreMatcher.MatchesPath(relPath+"/") { @@ -253,11 +191,11 @@ func DiscoverFiles(rootPath string) ([]SourceFile, string) { } ext := strings.ToLower(filepath.Ext(info.Name())) - if !sourceExtensions[ext] { + if !sourcefile.SourceExtensions[ext] { return nil } - if isTestFile(relPath) { + if sourcefile.IsTestFile(relPath) { return nil } @@ -305,7 +243,7 @@ func DiscoverFiles(rootPath string) ([]SourceFile, string) { for ext, count := range extCount { if count > maxCount { maxCount = count - language = languageHints[ext] + language = sourcefile.LanguageHints[ext] } } diff --git a/cohesion_backend/pkg/analyzer/interface.go b/cohesion_backend/pkg/analyzer/interface.go index 5035fb2..566a67d 100644 --- a/cohesion_backend/pkg/analyzer/interface.go +++ b/cohesion_backend/pkg/analyzer/interface.go @@ -11,31 +11,3 @@ type Analyzer interface { Framework() string Analyze(ctx context.Context, sourcePath string) ([]*schemair.SchemaIR, error) } - -type Registry struct { - analyzers map[string]Analyzer -} - -func NewRegistry() *Registry { - return &Registry{ - analyzers: make(map[string]Analyzer), - } -} - -func (r *Registry) Register(a Analyzer) { - key := a.Language() + ":" + a.Framework() - r.analyzers[key] = a -} - -func (r *Registry) Get(language, framework string) (Analyzer, bool) { - a, ok := r.analyzers[language+":"+framework] - return a, ok -} - -func (r *Registry) List() []Analyzer { - result := make([]Analyzer, 0, len(r.analyzers)) - for _, a := range r.analyzers { - result = append(result, a) - } - return result -} diff --git a/cohesion_backend/pkg/diff/engine.go b/cohesion_backend/pkg/diff/engine.go index 814a5ba..019ff13 100644 --- a/cohesion_backend/pkg/diff/engine.go +++ b/cohesion_backend/pkg/diff/engine.go @@ -2,6 +2,7 @@ package diff import ( "fmt" + "sort" "strings" "unicode" @@ -15,13 +16,34 @@ func NewEngine() *Engine { } var typeCompatGroups = [][]string{ - {"int", "integer", "number", "float"}, + {"int", "integer"}, + {"float", "double"}, {"bool", "boolean"}, {"string", "str"}, {"object", "map"}, {"array", "list"}, } +var typeSubtypes = map[string]string{ + "uuid": "string", + "date": "string", + "datetime": "string", + "date-time": "string", + "uri": "string", + "email": "string", + "time": "string", + "timestamp": "string", + + "int": "number", + "integer": "number", + "int32": "int", + "int64": "int", + "float": "number", + "float64": "float", + "float32": "float", + "double": "number", +} + var typeCanonical map[string]string func init() { @@ -42,6 +64,43 @@ func canonicalType(t string) string { return lower } +func wireType(t string) string { + lower := strings.ToLower(strings.TrimSpace(t)) + if parent, ok := typeSubtypes[lower]; ok { + return wireType(parent) + } + return canonicalType(lower) +} + +func numericGroup(t string) string { + lower := strings.ToLower(strings.TrimSpace(t)) + if canon, ok := typeCanonical[lower]; ok { + if canon == "int" || canon == "float" { + return canon + } + } + if parent, ok := typeSubtypes[lower]; ok { + return numericGroup(parent) + } + return "" +} + +func areSubtypeCompatible(typeA, typeB string) bool { + a := strings.ToLower(strings.TrimSpace(typeA)) + b := strings.ToLower(strings.TrimSpace(typeB)) + groupA, groupB := numericGroup(a), numericGroup(b) + if groupA != "" && groupB != "" && groupA != groupB { + return false + } + + _, aIsSub := typeSubtypes[a] + _, bIsSub := typeSubtypes[b] + if !aIsSub && !bIsSub { + return false + } + return wireType(a) == wireType(b) +} + func normalizeFieldName(name string) string { name = strings.ReplaceAll(name, "-", "_") @@ -61,6 +120,15 @@ func normalizeFieldName(name string) string { return string(result) } +func sortedSources(m map[schemair.SchemaSource]fieldInfo) []schemair.SchemaSource { + keys := make([]schemair.SchemaSource, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + sort.Slice(keys, func(i, j int) bool { return string(keys[i]) < string(keys[j]) }) + return keys +} + func (e *Engine) Compare(endpoint, method string, schemas []schemair.SchemaIR) *Result { result := &Result{ Endpoint: endpoint, @@ -126,10 +194,14 @@ type fieldInfo struct { func collectFields(obj *schemair.ObjectSchema, prefix string, source schemair.SchemaSource, fieldPresence map[string]map[schemair.SchemaSource]fieldInfo, ) { - if obj == nil || obj.Fields == nil { + if obj == nil { return } + for fieldName, field := range obj.Fields { + if field == nil { + continue + } normalName := normalizeFieldName(fieldName) path := prefix + normalName @@ -155,19 +227,19 @@ func collectFields(obj *schemair.ObjectSchema, prefix string, source schemair.Sc func (e *Engine) compareRequests(schemas []schemair.SchemaIR) []Mismatch { fieldPresence := make(map[string]map[schemair.SchemaSource]fieldInfo) - sourcesWithRequest := 0 + var contributingSources []schemair.SchemaSource for _, schema := range schemas { - if schema.Request != nil && schema.Request.Fields != nil { - sourcesWithRequest++ + if schema.Request != nil && (schema.Request.Fields != nil || schema.Request.Items != nil) { + contributingSources = append(contributingSources, schema.Source) collectFields(schema.Request, "request.", schema.Source, fieldPresence) } } - if sourcesWithRequest < 2 { + if len(contributingSources) < 2 { return nil } - return e.detectMismatches(fieldPresence, schemas, "request") + return e.detectMismatches(fieldPresence, contributingSources, "request") } func (e *Engine) compareResponses(schemas []schemair.SchemaIR) []Mismatch { @@ -186,7 +258,7 @@ func (e *Engine) compareResponses(schemas []schemair.SchemaIR) []Mismatch { for code := range statusCodes { fieldPresence := make(map[string]map[schemair.SchemaSource]fieldInfo) - sourcesWithCode := 0 + var contributingSources []schemair.SchemaSource for _, schema := range schemas { if schema.Response == nil { continue @@ -195,38 +267,42 @@ func (e *Engine) compareResponses(schemas []schemair.SchemaIR) []Mismatch { if !ok || resp == nil { continue } - sourcesWithCode++ + contributingSources = append(contributingSources, schema.Source) prefix := fmt.Sprintf("response.%d.", code) collectFields(resp, prefix, schema.Source, fieldPresence) } - if sourcesWithCode < 2 { + if len(contributingSources) < 2 { continue } - mismatches = append(mismatches, e.detectMismatches(fieldPresence, schemas, "response")...) + mismatches = append(mismatches, e.detectMismatches(fieldPresence, contributingSources, "response")...) } return mismatches } -func (e *Engine) detectMismatches(fieldPresence map[string]map[schemair.SchemaSource]fieldInfo, schemas []schemair.SchemaIR, section string) []Mismatch { +func (e *Engine) detectMismatches(fieldPresence map[string]map[schemair.SchemaSource]fieldInfo, contributingSources []schemair.SchemaSource, section string) []Mismatch { var mismatches []Mismatch - allSources := make(map[schemair.SchemaSource]bool) - for _, s := range schemas { - allSources[s.Source] = true + allSources := make(map[schemair.SchemaSource]bool, len(contributingSources)) + for _, s := range contributingSources { + allSources[s] = true } - for path, sourceMap := range fieldPresence { - presentSources := make([]schemair.SchemaSource, 0, len(sourceMap)) - for src := range sourceMap { - presentSources = append(presentSources, src) - } + paths := make([]string, 0, len(fieldPresence)) + for p := range fieldPresence { + paths = append(paths, p) + } + sort.Strings(paths) + + for _, path := range paths { + sourceMap := fieldPresence[path] + presentSources := sortedSources(sourceMap) if len(sourceMap) < len(allSources) { missingSources := make([]schemair.SchemaSource, 0) - for source := range allSources { + for _, source := range contributingSources { if _, ok := sourceMap[source]; !ok { missingSources = append(missingSources, source) } @@ -245,58 +321,67 @@ func (e *Engine) detectMismatches(fieldPresence map[string]map[schemair.SchemaSo } if len(sourceMap) >= 2 { - var firstCanon string - var firstSource schemair.SchemaSource - isFirst := true - - for source, info := range sourceMap { - canon := canonicalType(info.typ) - if isFirst { - firstCanon = canon - firstSource = source - isFirst = false - continue - } + sources := sortedSources(sourceMap) + seen := make(map[string]bool) + for i := 0; i < len(sources); i++ { + for j := i + 1; j < len(sources); j++ { + srcA, srcB := sources[i], sources[j] + infoA, infoB := sourceMap[srcA], sourceMap[srcB] + canonA, canonB := canonicalType(infoA.typ), canonicalType(infoB.typ) + + if canonA == canonB { + continue + } + + pairKey := canonA + ":" + canonB + if canonA > canonB { + pairKey = canonB + ":" + canonA + } + if seen[pairKey] { + continue + } + seen[pairKey] = true + + severity := SeverityCritical + suggestion := fmt.Sprintf("Align type to '%s' across all sources", infoA.typ) + + if areSubtypeCompatible(infoA.typ, infoB.typ) { + severity = SeverityInfo + suggestion = fmt.Sprintf("'%s' and '%s' are wire-compatible (both serialize as %s in JSON)", infoA.typ, infoB.typ, wireType(infoA.typ)) + } - if canon != firstCanon { mismatches = append(mismatches, Mismatch{ Path: path, Type: MismatchTypeDiff, - Description: fmt.Sprintf("Type mismatch: %s has '%s', %s has '%s'", firstSource, sourceMap[firstSource].typ, source, info.typ), - Expected: sourceMap[firstSource].typ, - Actual: info.typ, + Description: fmt.Sprintf("Type mismatch: %s has '%s', %s has '%s'", srcA, infoA.typ, srcB, infoB.typ), + Expected: infoA.typ, + Actual: infoB.typ, InSources: presentSources, - Severity: SeverityCritical, - Suggestion: fmt.Sprintf("Align type to '%s' across all sources", sourceMap[firstSource].typ), + Severity: severity, + Suggestion: suggestion, }) } } } if len(sourceMap) >= 2 { - var firstReq bool - var firstSource schemair.SchemaSource - isFirst := true - - for source, info := range sourceMap { - if isFirst { - firstReq = info.required - firstSource = source - isFirst = false - continue - } + sources := sortedSources(sourceMap) + refSource := sources[0] + refReq := sourceMap[refSource].required - if info.required != firstReq { + for _, src := range sources[1:] { + if sourceMap[src].required != refReq { mismatches = append(mismatches, Mismatch{ Path: path, Type: MismatchOptionality, - Description: fmt.Sprintf("Optionality mismatch: %s=%v, %s=%v", firstSource, firstReq, source, info.required), - Expected: firstReq, - Actual: info.required, + Description: fmt.Sprintf("Optionality mismatch: %s=%v, %s=%v", refSource, refReq, src, sourceMap[src].required), + Expected: refReq, + Actual: sourceMap[src].required, InSources: presentSources, Severity: SeverityWarning, Suggestion: "Consider aligning optionality across sources", }) + break } } } @@ -364,9 +449,8 @@ func (e *Engine) hasViolations(mismatches []Mismatch) bool { func (e *Engine) calculateConfidence(schemas []schemair.SchemaIR, mismatches []Mismatch) *EndpointConfidence { conf := &EndpointConfidence{ - Score: 0, - Breakdown: make(map[string]float64), - Factors: []string{}, + Score: 0, + Factors: []string{}, } var score float64 = 0 diff --git a/cohesion_backend/pkg/diff/engine_test.go b/cohesion_backend/pkg/diff/engine_test.go index 6f29cca..4da17b5 100644 --- a/cohesion_backend/pkg/diff/engine_test.go +++ b/cohesion_backend/pkg/diff/engine_test.go @@ -75,4 +75,97 @@ func TestEngine_Compare(t *testing.T) { t.Errorf("Expected Critical severity, got %v", result.Mismatches[0].Severity) } }) + + t.Run("UUID vs String is Info not Critical", func(t *testing.T) { + schemas := []schemair.SchemaIR{ + { + Source: schemair.SourceBackendStatic, + Response: map[int]*schemair.ObjectSchema{ + 200: {Fields: map[string]*schemair.Field{ + "project_id": {Type: "uuid"}, + }}, + }, + }, + { + Source: schemair.SourceFrontendStatic, + Response: map[int]*schemair.ObjectSchema{ + 200: {Fields: map[string]*schemair.Field{ + "project_id": {Type: "string"}, + }}, + }, + }, + } + + result := engine.Compare("/api/test", "GET", schemas) + if result.Status == schemair.StatusViolation { + t.Errorf("UUID vs string should not be a Violation, got %v", result.Status) + } + if len(result.Mismatches) != 1 { + t.Fatalf("Expected 1 mismatch, got %d", len(result.Mismatches)) + } + if result.Mismatches[0].Severity != SeverityInfo { + t.Errorf("Expected Info severity for uuid vs string, got %v", result.Mismatches[0].Severity) + } + }) + + t.Run("Time vs String is Info not Critical", func(t *testing.T) { + schemas := []schemair.SchemaIR{ + { + Source: schemair.SourceBackendStatic, + Response: map[int]*schemair.ObjectSchema{ + 200: {Fields: map[string]*schemair.Field{ + "created_at": {Type: "time"}, + }}, + }, + }, + { + Source: schemair.SourceRuntime, + Response: map[int]*schemair.ObjectSchema{ + 200: {Fields: map[string]*schemair.Field{ + "created_at": {Type: "string"}, + }}, + }, + }, + } + + result := engine.Compare("/api/test", "GET", schemas) + if result.Status == schemair.StatusViolation { + t.Errorf("Time vs string should not be a Violation, got %v", result.Status) + } + if len(result.Mismatches) != 1 { + t.Fatalf("Expected 1 mismatch, got %d", len(result.Mismatches)) + } + if result.Mismatches[0].Severity != SeverityInfo { + t.Errorf("Expected Info severity for time vs string, got %v", result.Mismatches[0].Severity) + } + }) + + t.Run("Real type mismatch still Critical", func(t *testing.T) { + schemas := []schemair.SchemaIR{ + { + Source: schemair.SourceBackendStatic, + Response: map[int]*schemair.ObjectSchema{ + 200: {Fields: map[string]*schemair.Field{ + "count": {Type: "integer"}, + }}, + }, + }, + { + Source: schemair.SourceFrontendStatic, + Response: map[int]*schemair.ObjectSchema{ + 200: {Fields: map[string]*schemair.Field{ + "count": {Type: "string"}, + }}, + }, + }, + } + + result := engine.Compare("/api/test", "GET", schemas) + if result.Status != schemair.StatusViolation { + t.Errorf("Integer vs string should be a Violation, got %v", result.Status) + } + if result.Mismatches[0].Severity != SeverityCritical { + t.Errorf("Expected Critical severity for integer vs string, got %v", result.Mismatches[0].Severity) + } + }) } diff --git a/cohesion_backend/pkg/diff/types.go b/cohesion_backend/pkg/diff/types.go index 56261c1..f0da747 100644 --- a/cohesion_backend/pkg/diff/types.go +++ b/cohesion_backend/pkg/diff/types.go @@ -40,7 +40,6 @@ type Result struct { } type EndpointConfidence struct { - Score float64 `json:"score"` - Breakdown map[string]float64 `json:"breakdown"` - Factors []string `json:"factors"` + Score float64 `json:"score"` + Factors []string `json:"factors"` } diff --git a/cohesion_backend/pkg/github/appauth.go b/cohesion_backend/pkg/github/appauth.go new file mode 100644 index 0000000..5c575f7 --- /dev/null +++ b/cohesion_backend/pkg/github/appauth.go @@ -0,0 +1,40 @@ +package github + +import ( + "net/http" + + "github.com/bradleyfalzon/ghinstallation/v2" + gh "github.com/google/go-github/v68/github" +) + +type AppAuth struct { + appID int64 + privateKey []byte +} + +func NewAppAuth(appID int64, privateKey []byte) *AppAuth { + if appID == 0 || len(privateKey) == 0 { + return nil + } + return &AppAuth{appID: appID, privateKey: privateKey} +} + +func (a *AppAuth) IsConfigured() bool { + return a != nil +} + +func (a *AppAuth) InstallationClient(installationID int64) (*gh.Client, error) { + transport, err := ghinstallation.New(http.DefaultTransport, a.appID, installationID, a.privateKey) + if err != nil { + return nil, err + } + return gh.NewClient(&http.Client{Transport: transport}), nil +} + +func (a *AppAuth) AppClient() (*gh.Client, error) { + transport, err := ghinstallation.NewAppsTransport(http.DefaultTransport, a.appID, a.privateKey) + if err != nil { + return nil, err + } + return gh.NewClient(&http.Client{Transport: transport}), nil +} diff --git a/cohesion_backend/pkg/github/fetcher.go b/cohesion_backend/pkg/github/fetcher.go index 093e42f..2a78e9b 100644 --- a/cohesion_backend/pkg/github/fetcher.go +++ b/cohesion_backend/pkg/github/fetcher.go @@ -10,53 +10,12 @@ import ( gh "github.com/google/go-github/v68/github" "github.com/cohesion-api/cohesion_backend/pkg/analyzer/gemini" + "github.com/cohesion-api/cohesion_backend/pkg/sourcefile" ) -var skipDirs = map[string]bool{ - "vendor": true, ".git": true, "node_modules": true, "__pycache__": true, - ".venv": true, "venv": true, "dist": true, "build": true, "target": true, - ".idea": true, ".vscode": true, ".next": true, ".nuxt": true, -} - -var sourceExtensions = map[string]bool{ - ".go": true, ".py": true, ".ts": true, ".js": true, ".java": true, - ".rb": true, ".rs": true, ".php": true, ".cs": true, ".kt": true, - ".ex": true, ".exs": true, ".scala": true, ".swift": true, -} - -var languageHints = map[string]string{ - ".go": "Go", ".py": "Python", ".ts": "TypeScript", ".js": "JavaScript", - ".java": "Java", ".rb": "Ruby", ".rs": "Rust", ".php": "PHP", - ".cs": "C#", ".kt": "Kotlin", ".ex": "Elixir", ".exs": "Elixir", - ".scala": "Scala", ".swift": "Swift", -} - const maxTotalBytes = 900_000 * 4 const maxFileBytes = 100 * 1024 -func isTestFile(path string) bool { - lower := strings.ToLower(path) - return strings.HasSuffix(lower, "_test.go") || - strings.HasPrefix(filepath.Base(lower), "test_") || - strings.HasSuffix(lower, "_test.py") || - strings.HasSuffix(lower, ".test.ts") || - strings.HasSuffix(lower, ".test.js") || - strings.HasSuffix(lower, ".spec.ts") || - strings.HasSuffix(lower, ".spec.js") || - strings.Contains(lower, "/test/") || - strings.Contains(lower, "/tests/") || - strings.Contains(lower, "/__tests__/") -} - -func inSkippedDir(path string) bool { - for _, part := range strings.Split(path, "/") { - if skipDirs[part] { - return true - } - } - return false -} - func ParseRepoURL(input string) (owner, repo string, err error) { input = strings.TrimSpace(input) input = strings.TrimSuffix(input, ".git") @@ -85,7 +44,10 @@ func FetchRepoFiles(ctx context.Context, token, owner, repo, branch, subPath str if token != "" { client = client.WithAuthToken(token) } + return FetchRepoFilesWithClient(ctx, client, owner, repo, branch, subPath) +} +func FetchRepoFilesWithClient(ctx context.Context, client *gh.Client, owner, repo, branch, subPath string) ([]gemini.SourceFile, string, error) { if branch == "" { branch = "main" } @@ -122,16 +84,16 @@ func FetchRepoFiles(ctx context.Context, token, owner, repo, branch, subPath str continue } - if inSkippedDir(path) { + if sourcefile.InSkippedDir(path) { continue } ext := strings.ToLower(filepath.Ext(path)) - if !sourceExtensions[ext] { + if !sourcefile.SourceExtensions[ext] { continue } - if isTestFile(path) { + if sourcefile.IsTestFile(path) { continue } @@ -196,7 +158,7 @@ func FetchRepoFiles(ctx context.Context, token, owner, repo, branch, subPath str for ext, count := range extCount { if count > maxCount { maxCount = count - language = languageHints[ext] + language = sourcefile.LanguageHints[ext] } } diff --git a/cohesion_backend/pkg/sourcefile/sourcefile.go b/cohesion_backend/pkg/sourcefile/sourcefile.go new file mode 100644 index 0000000..97a54e7 --- /dev/null +++ b/cohesion_backend/pkg/sourcefile/sourcefile.go @@ -0,0 +1,79 @@ +package sourcefile + +import ( + "path/filepath" + "strings" +) + +var SkipDirs = map[string]bool{ + "vendor": true, + ".git": true, + "node_modules": true, + "__pycache__": true, + ".venv": true, + "venv": true, + "dist": true, + "build": true, + "target": true, + ".idea": true, + ".vscode": true, + ".next": true, + ".nuxt": true, +} + +var SourceExtensions = map[string]bool{ + ".go": true, + ".py": true, + ".ts": true, + ".js": true, + ".java": true, + ".rb": true, + ".rs": true, + ".php": true, + ".cs": true, + ".kt": true, + ".ex": true, + ".exs": true, + ".scala": true, + ".swift": true, +} + +var LanguageHints = map[string]string{ + ".go": "Go", + ".py": "Python", + ".ts": "TypeScript", + ".js": "JavaScript", + ".java": "Java", + ".rb": "Ruby", + ".rs": "Rust", + ".php": "PHP", + ".cs": "C#", + ".kt": "Kotlin", + ".ex": "Elixir", + ".exs": "Elixir", + ".scala": "Scala", + ".swift": "Swift", +} + +func IsTestFile(path string) bool { + lower := strings.ToLower(path) + return strings.HasSuffix(lower, "_test.go") || + strings.HasPrefix(filepath.Base(lower), "test_") || + strings.HasSuffix(lower, "_test.py") || + strings.HasSuffix(lower, ".test.ts") || + strings.HasSuffix(lower, ".test.js") || + strings.HasSuffix(lower, ".spec.ts") || + strings.HasSuffix(lower, ".spec.js") || + strings.Contains(lower, "/test/") || + strings.Contains(lower, "/tests/") || + strings.Contains(lower, "/__tests__/") +} + +func InSkippedDir(path string) bool { + for _, part := range strings.Split(path, "/") { + if SkipDirs[part] { + return true + } + } + return false +} diff --git a/cohesion_frontend/.env.example b/cohesion_frontend/.env.example index a9f31f1..c65f3ca 100644 --- a/cohesion_frontend/.env.example +++ b/cohesion_frontend/.env.example @@ -2,4 +2,5 @@ NEXT_PUBLIC_API_URL= NEXT_PUBLIC_CLERK_SIGN_IN_URL=/ NEXT_PUBLIC_CLERK_SIGN_UP_URL= NEXT_PUBLIC_CLERK_PUBLISHABLE_KEY= -CLERK_SECRET_KEY= \ No newline at end of file +CLERK_SECRET_KEY= +NEXT_PUBLIC_GITHUB_APP_SLUG=cohesion \ No newline at end of file diff --git a/cohesion_frontend/src/app/live/page.tsx b/cohesion_frontend/src/app/live/page.tsx index f201c66..a39d28a 100644 --- a/cohesion_frontend/src/app/live/page.tsx +++ b/cohesion_frontend/src/app/live/page.tsx @@ -14,6 +14,7 @@ import { Columns2, ArrowRightLeft, Radio, + Workflow, } from "lucide-react"; import { Header } from "@/components/layout/header"; import { Button } from "@/components/ui/button"; @@ -22,12 +23,13 @@ import { LiveOnboarding } from "@/components/live/live-onboarding"; import { SourceConfig } from "@/components/live/source-config"; import { DualTrafficView } from "@/components/live/dual-traffic-view"; import { LiveDiffView } from "@/components/live/live-diff-view"; +import { LiveHandshakeView } from "@/components/live/live-handshake-view"; import { useAppStore } from "@/stores/app-store"; import { api } from "@/lib/api"; import { LiveCapturedRequest } from "@/lib/types"; import { enableFrontendCapture, disableFrontendCapture } from "@/lib/live-capture"; -type ViewMode = "unified" | "dual" | "diff"; +type ViewMode = "unified" | "dual" | "diff" | "handshake"; interface ProxySource { label: string; @@ -39,6 +41,7 @@ const VIEW_TABS: { id: ViewMode; label: string; icon: typeof Radio }[] = [ { id: "unified", label: "Unified", icon: Radio }, { id: "dual", label: "Dual Sources", icon: Columns2 }, { id: "diff", label: "Live Diff", icon: ArrowRightLeft }, + { id: "handshake", label: "Live Handshake", icon: Workflow }, ]; export default function LivePage() { @@ -235,7 +238,7 @@ export default function LivePage() { if (selectedSourceB === label) setSelectedSourceB(""); }; - const isDualOrDiff = viewMode === "dual" || viewMode === "diff"; + const isDualOrDiff = viewMode === "dual" || viewMode === "diff" || viewMode === "handshake"; return (
+ No accounts connected yet. Click "Connect Account" to install the Cohesion GitHub App. +
+ )} + + {/* Divider before PAT */}
- Create a{" "}
-
- Personal Access Token
-
- {" "}with repo scope to scan private repositories.
-
+ Create a{" "}
+
+ Personal Access Token
+
+ {" "}with repo scope to scan private repositories.
+
+ Visualize the API handshake between sources +
++ {canCompute + ? "Click Compute Handshake to see how frontend and backend schemas align" + : "Capture traffic from both sources first"} +
++ Assign which source represents your frontend and + backend using the dropdowns above +
++ Inferring schemas and building handshake + visualization... +
++ No endpoints inferred yet +
++ Select an endpoint to view the handshake +
+Scan a GitHub repository to extract API schemas. Requires a{" "} - GitHub token in Settings. + GitHub App connection or token in Settings.