diff --git a/backend/server/src/index.ts b/backend/server/src/index.ts index 142f3d3..f214d7d 100644 --- a/backend/server/src/index.ts +++ b/backend/server/src/index.ts @@ -5,7 +5,6 @@ import express, { Express } from "express" import fs from "fs" import { createServer } from "http" import { Server } from "socket.io" -import { z } from "zod" import { AIWorker } from "./AIWorker" import { CONTAINER_TIMEOUT } from "./constants" import { DokkuClient } from "./DokkuClient" @@ -18,9 +17,10 @@ import { saveFileRL, } from "./ratelimit" import { SecureGitClient } from "./SecureGitClient" +import { socketAuth } from "./socketAuth"; // Import the new socketAuth middleware import { handleCloseTerminal, handleCreateFile, handleCreateFolder, handleCreateTerminal, handleDeleteFile, handleDeleteFolder, handleDeploy, handleGenerateCode, handleGetFile, handleGetFolder, handleHeartbeat, handleListApps, handleMoveFile, HandlerContext, handleRenameFile, handleResizeTerminal, handleSaveFile, handleTerminalData } from "./SocketHandlers" import { TerminalManager } from "./TerminalManager" -import { DokkuResponse, User } from "./types" +import { DokkuResponse } from "./types" import { LockManager } from "./utils" // Handle uncaught exceptions @@ -37,6 +37,17 @@ process.on("unhandledRejection", (reason, promise) => { // You can also handle the rejected promise here if needed }) +// Check if the sandbox owner is connected +function isOwnerConnected(sandboxId: string): boolean { + return (connections[sandboxId] ?? 0) > 0 +} + +// Initialize containers and managers +const containers: Record = {} +const connections: Record = {} +const fileManagers: Record = {} +const terminalManagers: Record = {} + // Load environment variables dotenv.config() @@ -51,76 +62,8 @@ const io = new Server(httpServer, { }, }) -// Check if the sandbox owner is connected -function isOwnerConnected(sandboxId: string): boolean { - return (connections[sandboxId] ?? 0) > 0 -} - -// Initialize containers and managers -const containers: Record = {} -const connections: Record = {} -const fileManagers: Record = {} -const terminalManagers: Record = {} - // Middleware for socket authentication -io.use(async (socket, next) => { - // Define the schema for handshake query validation - const handshakeSchema = z.object({ - userId: z.string(), - sandboxId: z.string(), - EIO: z.string(), - transport: z.string(), - }) - - const q = socket.handshake.query - const parseQuery = handshakeSchema.safeParse(q) - - // Check if the query is valid according to the schema - if (!parseQuery.success) { - next(new Error("Invalid request.")) - return - } - - const { sandboxId, userId } = parseQuery.data - // Fetch user data from the database - const dbUser = await fetch( - `${process.env.DATABASE_WORKER_URL}/api/user?id=${userId}`, - { - headers: { - Authorization: `${process.env.WORKERS_KEY}`, - }, - } - ) - const dbUserJSON = (await dbUser.json()) as User - - // Check if user data was retrieved successfully - if (!dbUserJSON) { - next(new Error("DB error.")) - return - } - - // Check if the user owns the sandbox or has shared access - const sandbox = dbUserJSON.sandbox.find((s) => s.id === sandboxId) - const sharedSandboxes = dbUserJSON.usersToSandboxes.find( - (uts) => uts.sandboxId === sandboxId - ) - - // If user doesn't own or have shared access to the sandbox, deny access - if (!sandbox && !sharedSandboxes) { - next(new Error("Invalid credentials.")) - return - } - - // Set socket data with user information - socket.data = { - userId, - sandboxId: sandboxId, - isOwner: sandbox !== undefined, - } - - // Allow the connection - next() -}) +io.use(socketAuth) // Use the new socketAuth middleware // Initialize lock manager const lockManager = new LockManager() diff --git a/backend/server/src/socketAuth.ts b/backend/server/src/socketAuth.ts new file mode 100644 index 0000000..3bd83b1 --- /dev/null +++ b/backend/server/src/socketAuth.ts @@ -0,0 +1,63 @@ +import { Socket } from "socket.io" +import { z } from "zod" +import { User } from "./types" + +// Middleware for socket authentication +export const socketAuth = async (socket: Socket, next: Function) => { + // Define the schema for handshake query validation + const handshakeSchema = z.object({ + userId: z.string(), + sandboxId: z.string(), + EIO: z.string(), + transport: z.string(), + }) + + const q = socket.handshake.query + const parseQuery = handshakeSchema.safeParse(q) + + // Check if the query is valid according to the schema + if (!parseQuery.success) { + next(new Error("Invalid request.")) + return + } + + const { sandboxId, userId } = parseQuery.data + // Fetch user data from the database + const dbUser = await fetch( + `${process.env.DATABASE_WORKER_URL}/api/user?id=${userId}`, + { + headers: { + Authorization: `${process.env.WORKERS_KEY}`, + }, + } + ) + const dbUserJSON = (await dbUser.json()) as User + + // Check if user data was retrieved successfully + if (!dbUserJSON) { + next(new Error("DB error.")) + return + } + + // Check if the user owns the sandbox or has shared access + const sandbox = dbUserJSON.sandbox.find((s) => s.id === sandboxId) + const sharedSandboxes = dbUserJSON.usersToSandboxes.find( + (uts) => uts.sandboxId === sandboxId + ) + + // If user doesn't own or have shared access to the sandbox, deny access + if (!sandbox && !sharedSandboxes) { + next(new Error("Invalid credentials.")) + return + } + + // Set socket data with user information + socket.data = { + userId, + sandboxId: sandboxId, + isOwner: sandbox !== undefined, + } + + // Allow the connection + next() +}