From ce5e55bdf634457b84d1138a4b4f573a28cadecf Mon Sep 17 00:00:00 2001 From: Ishaan Dey Date: Mon, 13 May 2024 23:22:06 -0700 Subject: [PATCH] improve ai logic --- backend/server/src/index.ts | 432 ++++++++++++------------ backend/server/src/utils.ts | 22 -- frontend/components/editor/generate.tsx | 4 +- 3 files changed, 224 insertions(+), 234 deletions(-) diff --git a/backend/server/src/index.ts b/backend/server/src/index.ts index da57179..7eaefd6 100644 --- a/backend/server/src/index.ts +++ b/backend/server/src/index.ts @@ -1,26 +1,25 @@ -import fs from "fs" -import os from "os" -import path from "path" -import cors from "cors" -import express, { Express } from "express" -import dotenv from "dotenv" -import { createServer } from "http" -import { Server } from "socket.io" +import fs from "fs"; +import os from "os"; +import path from "path"; +import cors from "cors"; +import express, { Express } from "express"; +import dotenv from "dotenv"; +import { createServer } from "http"; +import { Server } from "socket.io"; -import { z } from "zod" -import { User } from "./types" +import { z } from "zod"; +import { User } from "./types"; import { createFile, deleteFile, - generateCode, getFolder, getProjectSize, getSandboxFiles, renameFile, saveFile, stopServer, -} from "./utils" -import { IDisposable, IPty, spawn } from "node-pty" +} from "./utils"; +import { IDisposable, IPty, spawn } from "node-pty"; import { MAX_BODY_SIZE, createFileRL, @@ -28,356 +27,359 @@ import { deleteFileRL, renameFileRL, saveFileRL, -} from "./ratelimit" +} from "./ratelimit"; -dotenv.config() +dotenv.config(); -const app: Express = express() -const port = process.env.PORT || 4000 -app.use(cors()) -const httpServer = createServer(app) +const app: Express = express(); +const port = process.env.PORT || 4000; +app.use(cors()); +const httpServer = createServer(app); const io = new Server(httpServer, { cors: { origin: "*", }, -}) +}); let inactivityTimeout: NodeJS.Timeout | null = null; let isOwnerConnected = false; const terminals: { - [id: string]: { terminal: IPty; onData: IDisposable; onExit: IDisposable } -} = {} + [id: string]: { terminal: IPty; onData: IDisposable; onExit: IDisposable }; +} = {}; -const dirName = path.join(__dirname, "..") +const dirName = path.join(__dirname, ".."); const handshakeSchema = z.object({ userId: z.string(), sandboxId: z.string(), EIO: z.string(), transport: z.string(), -}) +}); io.use(async (socket, next) => { - const q = socket.handshake.query - const parseQuery = handshakeSchema.safeParse(q) + const q = socket.handshake.query; + const parseQuery = handshakeSchema.safeParse(q); if (!parseQuery.success) { - ("Invalid request.") - next(new Error("Invalid request.")) - return + ("Invalid request."); + next(new Error("Invalid request.")); + return; } - const { sandboxId, userId } = parseQuery.data - const dbUser = await fetch(`https://database.ishaan1013.workers.dev/api/user?id=${userId}`) - const dbUserJSON = (await dbUser.json()) as User + const { sandboxId, userId } = parseQuery.data; + const dbUser = await fetch( + `https://database.ishaan1013.workers.dev/api/user?id=${userId}` + ); + const dbUserJSON = (await dbUser.json()) as User; if (!dbUserJSON) { - next(new Error("DB error.")) - return + next(new Error("DB error.")); + return; } - const sandbox = dbUserJSON.sandbox.find((s) => s.id === sandboxId) + const sandbox = dbUserJSON.sandbox.find((s) => s.id === sandboxId); const sharedSandboxes = dbUserJSON.usersToSandboxes.find( (uts) => uts.sandboxId === sandboxId - ) + ); if (!sandbox && !sharedSandboxes) { - next(new Error("Invalid credentials.")) - return + next(new Error("Invalid credentials.")); + return; } socket.data = { userId, sandboxId: sandboxId, isOwner: sandbox !== undefined, - } + }; - next() -}) + next(); +}); io.on("connection", async (socket) => { - if (inactivityTimeout) clearTimeout(inactivityTimeout); const data = socket.data as { - userId: string - sandboxId: string - isOwner: boolean - } + userId: string; + sandboxId: string; + isOwner: boolean; + }; if (data.isOwner) { - isOwnerConnected = true + isOwnerConnected = true; } else { if (!isOwnerConnected) { - socket.emit("disableAccess", "The sandbox owner is not connected.") - return + socket.emit("disableAccess", "The sandbox owner is not connected."); + return; } } - const sandboxFiles = await getSandboxFiles(data.sandboxId) + const sandboxFiles = await getSandboxFiles(data.sandboxId); sandboxFiles.fileData.forEach((file) => { - const filePath = path.join(dirName, file.id) - fs.mkdirSync(path.dirname(filePath), { recursive: true }) + const filePath = path.join(dirName, file.id); + fs.mkdirSync(path.dirname(filePath), { recursive: true }); fs.writeFile(filePath, file.data, function (err) { - if (err) throw err - }) - }) + if (err) throw err; + }); + }); - socket.emit("loaded", sandboxFiles.files) + socket.emit("loaded", sandboxFiles.files); socket.on("getFile", (fileId: string, callback) => { - const file = sandboxFiles.fileData.find((f) => f.id === fileId) - if (!file) return + const file = sandboxFiles.fileData.find((f) => f.id === fileId); + if (!file) return; - callback(file.data) - }) + callback(file.data); + }); socket.on("getFolder", async (folderId: string, callback) => { - const files = await getFolder(folderId) - callback(files) - }) + const files = await getFolder(folderId); + callback(files); + }); // todo: send diffs + debounce for efficiency socket.on("saveFile", async (fileId: string, body: string) => { try { - await saveFileRL.consume(data.userId, 1) + await saveFileRL.consume(data.userId, 1); if (Buffer.byteLength(body, "utf-8") > MAX_BODY_SIZE) { socket.emit( "rateLimit", "Rate limited: file size too large. Please reduce the file size." - ) - return + ); + return; } - const file = sandboxFiles.fileData.find((f) => f.id === fileId) - if (!file) return - file.data = body + const file = sandboxFiles.fileData.find((f) => f.id === fileId); + if (!file) return; + file.data = body; fs.writeFile(path.join(dirName, file.id), body, function (err) { - if (err) throw err - }) - await saveFile(fileId, body) + if (err) throw err; + }); + await saveFile(fileId, body); } catch (e) { - io.emit("rateLimit", "Rate limited: file saving. Please slow down.") + io.emit("rateLimit", "Rate limited: file saving. Please slow down."); } - }) + }); socket.on("moveFile", async (fileId: string, folderId: string, callback) => { - const file = sandboxFiles.fileData.find((f) => f.id === fileId) - if (!file) return + const file = sandboxFiles.fileData.find((f) => f.id === fileId); + if (!file) return; - const parts = fileId.split("/") - const newFileId = folderId + "/" + parts.pop() + const parts = fileId.split("/"); + const newFileId = folderId + "/" + parts.pop(); fs.rename( path.join(dirName, fileId), path.join(dirName, newFileId), function (err) { - if (err) throw err + if (err) throw err; } - ) + ); - file.id = newFileId + file.id = newFileId; - await renameFile(fileId, newFileId, file.data) - const newFiles = await getSandboxFiles(data.sandboxId) + await renameFile(fileId, newFileId, file.data); + const newFiles = await getSandboxFiles(data.sandboxId); - callback(newFiles.files) - }) + callback(newFiles.files); + }); socket.on("createFile", async (name: string, callback) => { try { - - const size: number = await getProjectSize(data.sandboxId) + const size: number = await getProjectSize(data.sandboxId); // limit is 200mb if (size > 200 * 1024 * 1024) { - io.emit("rateLimit", "Rate limited: project size exceeded. Please delete some files.") - callback({success: false}) + io.emit( + "rateLimit", + "Rate limited: project size exceeded. Please delete some files." + ); + callback({ success: false }); } - await createFileRL.consume(data.userId, 1) + await createFileRL.consume(data.userId, 1); - const id = `projects/${data.sandboxId}/${name}` + const id = `projects/${data.sandboxId}/${name}`; fs.writeFile(path.join(dirName, id), "", function (err) { - if (err) throw err - }) + if (err) throw err; + }); sandboxFiles.files.push({ id, name, type: "file", - }) + }); sandboxFiles.fileData.push({ id, data: "", - }) + }); - await createFile(id) + await createFile(id); - callback({success: true}) + callback({ success: true }); } catch (e) { - io.emit("rateLimit", "Rate limited: file creation. Please slow down.") + io.emit("rateLimit", "Rate limited: file creation. Please slow down."); } - }) + }); socket.on("createFolder", async (name: string, callback) => { try { - await createFolderRL.consume(data.userId, 1) + await createFolderRL.consume(data.userId, 1); - const id = `projects/${data.sandboxId}/${name}` + const id = `projects/${data.sandboxId}/${name}`; fs.mkdir(path.join(dirName, id), { recursive: true }, function (err) { - if (err) throw err - }) + if (err) throw err; + }); - callback() + callback(); } catch (e) { - io.emit("rateLimit", "Rate limited: folder creation. Please slow down.") + io.emit("rateLimit", "Rate limited: folder creation. Please slow down."); } - }) + }); socket.on("renameFile", async (fileId: string, newName: string) => { try { - await renameFileRL.consume(data.userId, 1) + await renameFileRL.consume(data.userId, 1); - const file = sandboxFiles.fileData.find((f) => f.id === fileId) - if (!file) return - file.id = newName + const file = sandboxFiles.fileData.find((f) => f.id === fileId); + if (!file) return; + file.id = newName; - const parts = fileId.split("/") + const parts = fileId.split("/"); const newFileId = - parts.slice(0, parts.length - 1).join("/") + "/" + newName + parts.slice(0, parts.length - 1).join("/") + "/" + newName; fs.rename( path.join(dirName, fileId), path.join(dirName, newFileId), function (err) { - if (err) throw err + if (err) throw err; } - ) - await renameFile(fileId, newFileId, file.data) + ); + await renameFile(fileId, newFileId, file.data); } catch (e) { - io.emit("rateLimit", "Rate limited: file renaming. Please slow down.") - return + io.emit("rateLimit", "Rate limited: file renaming. Please slow down."); + return; } - }) + }); socket.on("deleteFile", async (fileId: string, callback) => { try { - await deleteFileRL.consume(data.userId, 1) - const file = sandboxFiles.fileData.find((f) => f.id === fileId) - if (!file) return + await deleteFileRL.consume(data.userId, 1); + const file = sandboxFiles.fileData.find((f) => f.id === fileId); + if (!file) return; fs.unlink(path.join(dirName, fileId), function (err) { - if (err) throw err - }) + if (err) throw err; + }); sandboxFiles.fileData = sandboxFiles.fileData.filter( (f) => f.id !== fileId - ) + ); - await deleteFile(fileId) + await deleteFile(fileId); - const newFiles = await getSandboxFiles(data.sandboxId) - callback(newFiles.files) + const newFiles = await getSandboxFiles(data.sandboxId); + callback(newFiles.files); } catch (e) { - io.emit("rateLimit", "Rate limited: file deletion. Please slow down.") + io.emit("rateLimit", "Rate limited: file deletion. Please slow down."); } - }) + }); socket.on("renameFolder", async (folderId: string, newName: string) => { - // todo - }) + // todo + }); socket.on("deleteFolder", async (folderId: string, callback) => { - const files = await getFolder(folderId) + const files = await getFolder(folderId); - await Promise.all(files.map(async (file) => { - fs.unlink(path.join(dirName, file), function (err) { - if (err) throw err + await Promise.all( + files.map(async (file) => { + fs.unlink(path.join(dirName, file), function (err) { + if (err) throw err; + }); + + sandboxFiles.fileData = sandboxFiles.fileData.filter( + (f) => f.id !== file + ); + + await deleteFile(file); }) + ); - sandboxFiles.fileData = sandboxFiles.fileData.filter( - (f) => f.id !== file - ) + const newFiles = await getSandboxFiles(data.sandboxId); - await deleteFile(file) - })) - - const newFiles = await getSandboxFiles(data.sandboxId) - - callback(newFiles.files) - - }) + callback(newFiles.files); + }); socket.on("createTerminal", (id: string, callback) => { - console.log("creating terminal", id) + console.log("creating terminal", id); if (terminals[id] || Object.keys(terminals).length >= 4) { - return + return; } const pty = spawn(os.platform() === "win32" ? "cmd.exe" : "bash", [], { name: "xterm", cols: 100, cwd: path.join(dirName, "projects", data.sandboxId), - }) + }); const onData = pty.onData((data) => { // console.log("terminalResponse", id, data) io.emit("terminalResponse", { id, data, - }) - }) + }); + }); - const onExit = pty.onExit((code) => console.log("exit :(", code)) + const onExit = pty.onExit((code) => console.log("exit :(", code)); - pty.write("clear\r") + pty.write("clear\r"); terminals[id] = { terminal: pty, onData, onExit, - } + }; - callback() - }) + callback(); + }); socket.on("resizeTerminal", (dimensions: { cols: number; rows: number }) => { - console.log("resizeTerminal", dimensions) + console.log("resizeTerminal", dimensions); Object.values(terminals).forEach((t) => { - t.terminal.resize(dimensions.cols, dimensions.rows) - }) - - }) + t.terminal.resize(dimensions.cols, dimensions.rows); + }); + }); socket.on("terminalData", (id: string, data: string) => { if (!terminals[id]) { - console.log("terminal not found", id) - return + console.log("terminal not found", id); + return; } try { - terminals[id].terminal.write(data) + terminals[id].terminal.write(data); } catch (e) { - console.log("Error writing to terminal", e) + console.log("Error writing to terminal", e); } - }) + }); socket.on("closeTerminal", (id: string, callback) => { if (!terminals[id]) { - return + return; } - terminals[id].onData.dispose() - terminals[id].onExit.dispose() - delete terminals[id] + terminals[id].onData.dispose(); + terminals[id].onExit.dispose(); + delete terminals[id]; - callback() - }) + callback(); + }); socket.on( "generateCode", @@ -389,70 +391,80 @@ io.on("connection", async (socket) => { callback ) => { // Log code generation credit in DB - const fetchPromise = fetch(`https://database.ishaan1013.workers.dev/api/sandbox/generate`, { - method: "POST", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify({ - userId: data.userId, - }), - }) + const fetchPromise = fetch( + `https://database.ishaan1013.workers.dev/api/sandbox/generate`, + { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + userId: data.userId, + }), + } + ); + + console.log("CF_AI_KEY", process.env.CF_AI_KEY); // Generate code from cloudflare workers AI - const generateCodePromise = generateCode({ - fileName, - code, - line, - instructions, - }) + const generateCodePromise = fetch( + `https://ai.ishaan1013.workers.dev/api?fileName=${fileName}&code=${code}&line=${line}&instructions=${instructions}`, + { + headers: { + "Content-Type": "application/json", + Authorization: `${process.env.CF_AI_KEY}`, + }, + } + ); const [fetchResponse, generateCodeResponse] = await Promise.all([ fetchPromise, generateCodePromise, - ]) + ]); - const json = await generateCodeResponse.json() - callback(json) + const json = await generateCodeResponse.json(); + callback(json); } - ) + ); socket.on("disconnect", async () => { - console.log("disconnected", data.userId, data.sandboxId) + console.log("disconnected", data.userId, data.sandboxId); if (data.isOwner) { // console.log("deleting all terminals") Object.entries(terminals).forEach((t) => { - const { terminal, onData, onExit } = t[1] - onData.dispose() - onExit.dispose() - delete terminals[t[0]] - }) + const { terminal, onData, onExit } = t[1]; + onData.dispose(); + onExit.dispose(); + delete terminals[t[0]]; + }); - socket.broadcast.emit("disableAccess", "The sandbox owner has disconnected.") + socket.broadcast.emit( + "disableAccess", + "The sandbox owner has disconnected." + ); } - const sockets = await io.fetchSockets() + const sockets = await io.fetchSockets(); if (inactivityTimeout) { - clearTimeout(inactivityTimeout) - }; + clearTimeout(inactivityTimeout); + } if (sockets.length === 0) { - console.log("STARTING TIMER") + console.log("STARTING TIMER"); inactivityTimeout = setTimeout(() => { io.fetchSockets().then(async (sockets) => { if (sockets.length === 0) { // close server console.log("Closing server due to inactivity."); - const res = await stopServer(data.sandboxId, data.userId) + const res = await stopServer(data.sandboxId, data.userId); } - }); + }); }, 20000); } else { - console.log("number of sockets", sockets.length) + console.log("number of sockets", sockets.length); } - - }) -}) + }); +}); httpServer.listen(port, () => { - console.log(`Server, running on port ${port}`) -}) + console.log(`Server, running on port ${port}`); +}); diff --git a/backend/server/src/utils.ts b/backend/server/src/utils.ts index 851c703..ef99153 100644 --- a/backend/server/src/utils.ts +++ b/backend/server/src/utils.ts @@ -151,28 +151,6 @@ export const getProjectSize = async (id: string) => { return (await res.json()).size; }; -export const generateCode = async ({ - fileName, - code, - line, - instructions, -}: { - fileName: string; - code: string; - line: number; - instructions: string; -}) => { - return await fetch( - `https://ai.ishaan1013.workers.dev/api?fileName=${fileName}&code=${code}&line=${line}&instructions=${instructions}`, - { - headers: { - "Content-Type": "application/json", - Authorization: `${process.env.CF_AI_KEY}`, - }, - } - ); -}; - export const stopServer = async (sandboxId: string, userId: string) => { const res = await fetch("http://localhost:4001/stop", { method: "POST", diff --git a/frontend/components/editor/generate.tsx b/frontend/components/editor/generate.tsx index 1667253..52e1e28 100644 --- a/frontend/components/editor/generate.tsx +++ b/frontend/components/editor/generate.tsx @@ -49,8 +49,8 @@ export default function GenerateInput({ useEffect(() => { setTimeout(() => { inputRef.current?.focus(); - }, 0); - }, []); + }, 100); + }, [inputRef.current]); const handleGenerate = async ({ regenerate = false,