diff --git a/lib/cache.ts b/lib/cache.ts index d1a56b5e..3d09438b 100644 --- a/lib/cache.ts +++ b/lib/cache.ts @@ -1,4 +1,4 @@ -import { User, Website, Team } from '@prisma/client'; +import { User, Website, Team, TeamUser } from '@prisma/client'; import redis from '@umami/redis-client'; import { lightFormat, startOfMonth } from 'date-fns'; import { getAllWebsitesByUser, getSession, getUser, getViewTotals, getWebsite } from '../queries'; @@ -35,10 +35,10 @@ async function deleteObject(key, soft = false) { async function fetchWebsite(id): Promise< Website & { - team?: Team; + team?: Team & { teamUsers: TeamUser[] }; } > { - return fetchObject(`website:${id}`, () => getWebsite({ id })); + return fetchObject(`website:${id}`, () => getWebsite({ id }, true)); } async function storeWebsite(data) { diff --git a/lib/middleware.ts b/lib/middleware.ts index e8f56b1a..67ae71ad 100644 --- a/lib/middleware.ts +++ b/lib/middleware.ts @@ -1,20 +1,57 @@ -import { createMiddleware, unauthorized, badRequest, parseSecureToken } from 'next-basics'; -import debug from 'debug'; -import cors from 'cors'; -import { validate } from 'uuid'; import redis from '@umami/redis-client'; -import { findSession } from 'lib/session'; +import cors from 'cors'; +import debug from 'debug'; import { getAuthToken, parseShareToken } from 'lib/auth'; -import { secret } from 'lib/crypto'; import { ROLES } from 'lib/constants'; +import { secret } from 'lib/crypto'; +import { isOverApiLimit, findSession, findWebsite, useSessionCache } from 'lib/session'; +import { + badRequest, + createMiddleware, + parseSecureToken, + tooManyRequest, + unauthorized, +} from 'next-basics'; +import { NextApiRequestCollect } from 'pages/api/collect'; +import { validate } from 'uuid'; import { getUser } from '../queries'; +import { getJsonBody } from './detect'; const log = debug('umami:middleware'); export const useCors = createMiddleware(cors()); -export const useSession = createMiddleware(async (req: any, res, next) => { - const session = await findSession(req); +export const useSession = createMiddleware(async (req: NextApiRequestCollect, res, next) => { + // Verify payload + const { payload } = getJsonBody(req); + + const { website: websiteId } = payload; + + if (!payload) { + log('useSession: No payload'); + return badRequest(res); + } + + // Get session from cache + let session = await useSessionCache(req); + + // Get or create session + if (!session) { + const website = await findWebsite(websiteId); + + if (!website) { + log('useSession: Website not found'); + return badRequest(res); + } + + if (process.env.ENABLE_COLLECT_LIMIT) { + if (isOverApiLimit(website)) { + return tooManyRequest(res, 'Collect currently exceeds monthly limit of 10000.'); + } + } + + session = await findSession(req, payload); + } if (!session) { log('useSession: Session not found'); diff --git a/lib/session.ts b/lib/session.ts index 08a54fc8..a2e87ab6 100644 --- a/lib/session.ts +++ b/lib/session.ts @@ -2,84 +2,41 @@ import { Session, Team, Website } from '@prisma/client'; import cache from 'lib/cache'; import clickhouse from 'lib/clickhouse'; import { secret, uuid } from 'lib/crypto'; -import { getClientInfo, getJsonBody } from 'lib/detect'; +import { getClientInfo } from 'lib/detect'; import { parseToken } from 'next-basics'; import { createSession, getSession, getWebsite } from 'queries'; -import { validate } from 'uuid'; -export async function findSession(req): Promise<{ - error?: { - status: number; - message: string; - }; - session?: { - id: string; - websiteId: string; - hostname: string; - browser: string; - os: string; - device: string; - screen: string; - language: string; - country: string; - }; - website?: Website & { team?: Team }; +export async function findSession( + req, + payload, +): Promise<{ + id: string; + websiteId: string; + hostname: string; + browser: string; + os: string; + device: string; + screen: string; + language: string; + country: string; }> { - const { payload } = getJsonBody(req); - - if (!payload) { - return null; - } - - // Verify payload const { website: websiteId, hostname, screen, language } = payload; - // Find website - let website: Website & { team?: Team } = null; - - if (cache.enabled) { - website = await cache.fetchWebsite(websiteId); - } else { - website = await getWebsite({ id: websiteId }); - } - - if (!website || website.deletedAt) { - throw new Error(`Website not found: ${websiteId}`); - } - - // Check if cache token is passed - const cacheToken = req.headers['x-umami-cache']; - - if (cacheToken) { - const result = await parseToken(cacheToken, secret()); - - if (result) { - return { session: result, website }; - } - } - - if (!validate(websiteId)) { - return null; - } - const { userAgent, browser, os, ip, country, device } = await getClientInfo(req, payload); const sessionId = uuid(websiteId, hostname, ip, userAgent); // Clickhouse does not require session lookup if (clickhouse.enabled) { return { - session: { - id: sessionId, - websiteId, - hostname, - browser, - os, - device, - screen, - language, - country, - }, - website, + id: sessionId, + websiteId, + hostname, + browser, + os, + device, + screen, + language, + country, }; } @@ -113,5 +70,61 @@ export async function findSession(req): Promise<{ } } - return { session, website }; + return session; +} + +export async function useSessionCache(req: any): Promise<{ + id: string; + websiteId: string; + hostname: string; + browser: string; + os: string; + device: string; + screen: string; + language: string; + country: string; +}> { + // Check if cache token is passed + const cacheToken = req.headers['x-umami-cache']; + + if (cacheToken) { + const result = await parseToken(cacheToken, secret()); + + if (result) { + return result; + } + } + + return null; +} + +export async function findWebsite(websiteId: string) { + let website: Website & { team?: Team } = null; + + if (cache.enabled) { + website = await cache.fetchWebsite(websiteId); + } else { + website = await getWebsite({ id: websiteId }, true); + } + + if (!website || website.deletedAt) { + throw new Error(`Website not found: ${websiteId}`); + } + + return website; +} + +export async function isOverApiLimit(website) { + const userId = website.userId ? website.userId : website.team.teamu.userId; + + const limit = await cache.fetchCollectLimit(userId); + + // To-do: Need to implement logic to find user-specific limit. Defaulted to 10k. + if (limit > 10000) { + return true; + } + + await cache.incrementCollectLimit(userId); + + return false; } diff --git a/pages/api/collect.ts b/pages/api/collect.ts index 2f70bc08..67714fed 100644 --- a/pages/api/collect.ts +++ b/pages/api/collect.ts @@ -1,40 +1,24 @@ const { Resolver } = require('dns').promises; -import isbot from 'isbot'; import ipaddr from 'ipaddr.js'; -import { - createToken, - unauthorized, - send, - badRequest, - forbidden, - tooManyRequest, -} from 'next-basics'; -import { savePageView, saveEvent } from 'queries'; -import { useCors, useSession } from 'lib/middleware'; -import { getJsonBody, getIpAddress } from 'lib/detect'; +import isbot from 'isbot'; import { secret } from 'lib/crypto'; +import { getIpAddress, getJsonBody } from 'lib/detect'; +import { useCors, useSession } from 'lib/middleware'; import { NextApiRequest, NextApiResponse } from 'next'; -import cache from 'lib/cache'; -import { Team, Website } from '@prisma/client'; +import { badRequest, createToken, forbidden, send, unauthorized } from 'next-basics'; +import { saveEvent, savePageView } from 'queries'; export interface NextApiRequestCollect extends NextApiRequest { - session: { - error?: { - status: number; - message: string; - }; - session?: { - id: string; - websiteId: string; - hostname: string; - browser: string; - os: string; - device: string; - screen: string; - language: string; - country: string; - }; - website?: Website & { team?: Team }; + session?: { + id: string; + websiteId: string; + hostname: string; + browser: string; + os: string; + device: string; + screen: string; + language: string; + country: string; }; } @@ -104,21 +88,7 @@ export default async (req: NextApiRequestCollect, res: NextApiResponse) => { await useSession(req, res); - const { session, website } = req.session; - - // Check collection limit - if (process.env.ENABLE_COLLECT_LIMIT) { - const userId = website.userId ? website.userId : website.team.userId; - - const limit = await cache.fetchCollectLimit(userId); - - // To-do: Need to implement logic to find user-specific limit. Defaulted to 10k. - if (limit > 10000) { - return tooManyRequest(res, 'Collect currently exceeds monthly limit of 10000.'); - } - - await cache.incrementCollectLimit(userId); - } + const session = req.session; if (process.env.REMOVE_TRAILING_SLASH) { url = url.replace(/\/$/, ''); diff --git a/queries/admin/website.ts b/queries/admin/website.ts index e820bc84..74f208ef 100644 --- a/queries/admin/website.ts +++ b/queries/admin/website.ts @@ -1,18 +1,29 @@ -import { Prisma, Team, Website } from '@prisma/client'; +import { Prisma, Team, TeamUser, Website } from '@prisma/client'; import cache from 'lib/cache'; import prisma from 'lib/prisma'; import { runQuery, CLICKHOUSE, PRISMA } from 'lib/db'; -export async function getWebsite(where: Prisma.WebsiteWhereUniqueInput): Promise< +export async function getWebsite( + where: Prisma.WebsiteWhereUniqueInput, + includeTeamData = false, +): Promise< Website & { - team?: Team; + team?: Team & { teamUsers: TeamUser[] }; } > { + prisma.client.team.findMany(); + return prisma.client.website.findUnique({ where, - include: { - team: true, - }, + include: includeTeamData + ? { + team: { + include: { + teamUsers: true, + }, + }, + } + : {}, }); }