Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { NextRequest } from "next/server";

import NileAuth from "@nile-auth/core";
import { EventEnum, Logger, ResponseLogger } from "@nile-auth/logger";
import { getOrigin, getSecureCookies } from "@nile-auth/core/cookies";
import { getOrigin } from "@nile-auth/core/cookies";

const log = Logger(EventEnum.NILE_AUTH);

Expand Down
1 change: 0 additions & 1 deletion apps/server/app/v2/databases/[database]/auth/mfa/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,6 @@ export async function PUT(req: NextRequest) {
`;

if (sessionError) {
console.log(sessionError);
return sessionError;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ jest.mock("@nile-auth/logger", () => ({
(body, { status = 200, headers = {} }) =>
new Response(body, { status, headers }),
),
{ error: (e: Error) => console.log(e) },
{ error: (e: Error) => jest.fn },
],
}));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ describe("accept invite", () => {

values.forEach((val, idx) => {
const normalized =
val && typeof val === "object" && "value" in (val as Record<string, any>)
val &&
typeof val === "object" &&
"value" in (val as Record<string, any>)
? (val as { value: string }).value
: val;
text = text.replace(`$${idx + 1}`, normalized as string);
Expand Down Expand Up @@ -129,7 +131,9 @@ describe("accept invite", () => {
}
values.forEach((val, i) => {
const normalized =
val && typeof val === "object" && "value" in (val as Record<string, any>)
val &&
typeof val === "object" &&
"value" in (val as Record<string, any>)
? (val as { value: string }).value
: val;
text = text.replace(`$${i + 1}`, normalized as string);
Expand All @@ -141,7 +145,6 @@ describe("accept invite", () => {
commands.push(text);

if (text.includes("DELETE")) {
console.log(text, "not here?");
return [null, { rowCount: 1 }];
}

Expand Down
70 changes: 57 additions & 13 deletions apps/server/test/integration.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ const primaryUser = {
email: "delete@me.com",
password: "deleteme",
};
const staleJwtUser = {
email: "delete4@me.com",
password: "deleteme",
};
const newUser = {
email: "delete2@me.com",
password: "deleteme",
Expand Down Expand Up @@ -165,23 +169,63 @@ describe("api integration", () => {
await nile.db.query("delete from tenants where id = $1", [newTenant.id]);
await nile.clearConnections();
}, 10000);

test("revoked JWT cannot be reused after sign out", async () => {
const nile = new Server(config);
await initialDebugCleanup(nile);

const user = (await nile.api.users.createUser(
staleJwtUser,
)) as unknown as { id: string };
expect(user.id).toBeTruthy();

await nile.api.login(staleJwtUser, { returnResponse: true });

const activeMe = await nile.api.users.me<{ email: string }>();
expect(activeMe.email).toEqual(staleJwtUser.email);

// capture the auth cookies before revocation
const staleHeaders = new Headers(nile.api.headers);

const signOutRes = (await nile.api.auth.signOut({
callbackUrl: "http://localhost:3000",
})) as Response | { url: string };
if (signOutRes instanceof Response) {
expect(signOutRes.status).toEqual(200);
} else {
expect(signOutRes.url).toEqual("http://localhost:3000");
}

const staleMe = (await nile.api.users.me<Response>(
staleHeaders,
)) as Response;
expect(staleMe.status).toEqual(401);

await nile.db.query("delete from auth.credentials where user_id = $1", [
user.id,
]);
await nile.db.query("delete from users.users where id = $1", [user.id]);
await nile.clearConnections();
}, 10000);
});

async function initialDebugCleanup(nile: Server) {
// remove the users 1st, fk constraints
const existing = [primaryUser, newUser, tenantUser].map(async (u) => {
const exists = await nile.db.query(
"select * from users.users where email = $1",
[u.email],
);
if (exists.rows.length > 0) {
const id = exists.rows[0].id;
await nile.db.query("delete from auth.credentials where user_id = $1", [
id,
]);
await nile.db.query("delete from users.users where id= $1", [id]);
}
});
const existing = [primaryUser, newUser, tenantUser, staleJwtUser].map(
async (u) => {
const exists = await nile.db.query(
"select * from users.users where email = $1",
[u.email],
);
if (exists.rows.length > 0) {
const id = exists.rows[0].id;
await nile.db.query("delete from auth.credentials where user_id = $1", [
id,
]);
await nile.db.query("delete from users.users where id= $1", [id]);
}
},
);
await Promise.all(existing);
const tenants = await nile.db.query("select * from tenants;");
const commands = tenants.rows.reduce((accum, t) => {
Expand Down
36 changes: 35 additions & 1 deletion packages/core/src/auth.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import { getToken, JWT } from "next-auth/jwt";
import { Logger } from "@nile-auth/logger";
import { Pool } from "pg";
import getDbInfo from "@nile-auth/query/getDbInfo";
import { query } from "@nile-auth/query";
import { getSecureCookies } from "./next-auth/cookies";

type SessionUser = { user?: { id?: string } };
Expand Down Expand Up @@ -31,7 +34,38 @@ export async function buildFetch(
!isNaN(token.exp) &&
token.exp > now
) {
return [{ user: { id: String(token.id) } }];
try {
if (typeof token.jti !== "string") {
throw new Error("JWT missing jti");
}
const dbInfo = getDbInfo(undefined, req);
const pool = new Pool(dbInfo);
const sql = await query(pool);
const sessions = await sql`
SELECT
expires_at
FROM
auth.sessions
WHERE
session_token = ${token.jti}
`;
if (
sessions &&
"rowCount" in sessions &&
sessions.rowCount > 0 &&
sessions.rows[0]?.expires_at &&
new Date(sessions.rows[0].expires_at).getTime() > Date.now()
) {
return [{ user: { id: String(token.id) } }];
}
} catch (e) {
if (e instanceof Error) {
warn("revocation check failed", {
message: e.message,
stack: e.stack,
});
}
}
}
}
const url = new URL(req.url);
Expand Down
141 changes: 141 additions & 0 deletions packages/core/src/index.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
jest.mock("next-auth", () => ({
__esModule: true,
default: jest.fn(),
}));

jest.mock("./nextOptions", () => ({
nextOptions: jest.fn(),
}));

jest.mock("./utils", () => ({
buildOptions: jest.fn(),
}));

jest.mock("@nile-auth/query/getDbInfo", () => ({
__esModule: true,
default: jest.fn(),
}));

jest.mock("./next-auth/cookies", () => {
const actual = jest.requireActual("./next-auth/cookies");
return {
...actual,
getOrigin: jest.fn(),
getTenantCookie: jest.fn(),
setTenantCookie: jest.fn(),
};
});

jest.mock("@nile-auth/query", () => ({
queryByReq: jest.fn(),
}));

jest.mock("./mfa/providerResponse", () => ({
buildProviderMfaResponse: jest.fn(),
}));

jest.mock("./next-auth/providers/email", () => ({
sendVerifyEmail: jest.fn(),
}));

import NileAuth from "./index";
import NextAuth from "next-auth";
import { nextOptions } from "./nextOptions";
import { buildOptions } from "./utils";
import getDbInfo from "@nile-auth/query/getDbInfo";
import {
getTenantCookie,
getOrigin,
setTenantCookie,
} from "./next-auth/cookies";
import { queryByReq } from "@nile-auth/query";
import { buildProviderMfaResponse } from "./mfa/providerResponse";

const nextAuthMock = NextAuth as jest.MockedFunction<typeof NextAuth>;
const nextOptionsMock = nextOptions as jest.MockedFunction<typeof nextOptions>;
const buildOptionsMock = buildOptions as jest.MockedFunction<
typeof buildOptions
>;
const getDbInfoMock = getDbInfo as jest.MockedFunction<typeof getDbInfo>;
const getTenantCookieMock = getTenantCookie as jest.MockedFunction<
typeof getTenantCookie
>;
const getOriginMock = getOrigin as jest.MockedFunction<typeof getOrigin>;
const setTenantCookieMock = setTenantCookie as jest.MockedFunction<
typeof setTenantCookie
>;
const queryByReqMock = queryByReq as jest.MockedFunction<typeof queryByReq>;
const buildProviderMfaResponseMock =
buildProviderMfaResponse as jest.MockedFunction<
typeof buildProviderMfaResponse
>;

describe("NileAuth", () => {
const dbInfo = {
host: "localhost",
database: "nile",
user: "nile",
password: "secret",
port: 5432,
};

beforeEach(() => {
jest.clearAllMocks();
getDbInfoMock.mockReturnValue(dbInfo as any);
nextOptionsMock.mockResolvedValue([{ providers: [{}] } as any]);
buildOptionsMock.mockReturnValue({} as any);
getOriginMock.mockReturnValue("https://example.com");
getTenantCookieMock.mockReturnValue(null);
buildProviderMfaResponseMock.mockResolvedValue(null);
nextAuthMock.mockResolvedValue(new Response(null, { status: 200 }));
setTenantCookieMock.mockReturnValue(
new Headers([["set-cookie", "tenant=new"]]),
);
});

it("appends tenant cookie header after a successful credentials callback", async () => {
const executedQueries: string[] = [];
queryByReqMock.mockResolvedValueOnce(async function sql(
strings: TemplateStringsArray,
...values: unknown[]
): Promise<any> {
let text = strings[0] ?? "";
for (let i = 1; i < strings.length; i++) {
text += `$${i}${strings[i] ?? ""}`;
}
values.forEach((val, idx) => {
text = text.replace(`$${idx + 1}`, String(val));
});
text = text.replace(/\n\s+/g, " ").trim();
executedQueries.push(text);
return [
{
rowCount: 1,
rows: [{ id: "tenant-123", name: "Tenant 123" }],
},
];
});

const form = new FormData();
form.set("email", "unit@example.com");
const req = new Request(
"https://example.com/api/auth/[...nextauth]/callback/credentials",
{
method: "POST",
body: form,
},
);

const res = await NileAuth(req, {
params: { nextauth: ["callback", "credentials"] },
});

expect(setTenantCookieMock).toHaveBeenCalledWith(req, [
{ id: "tenant-123", name: "Tenant 123" },
]);
expect(res.headers.get("set-cookie")).toEqual("tenant=new");
expect(executedQueries).toEqual([
"SELECT DISTINCT t.id, t.name FROM public.tenants t JOIN users.tenant_users tu ON t.id = tu.tenant_id JOIN users.users u ON u.id = tu.user_id WHERE LOWER(u.email) = LOWER(unit@example.com) AND tu.deleted IS NULL AND t.deleted IS NULL AND u.deleted IS NULL",
]);
});
});
Loading
Loading