const os = require('os');
const pty = require('node-pty');
const path = require('path');
const fs = require("fs");
const { createServer } = require("http");
const { Server } = require("socket.io");
const { io } = require("socket.io-client");
const term = require('terminal-kit').terminal;
const Downloader = require("nodejs-file-downloader");

// Commands are run through a pseudo-terminal; pick the platform shell.
const shell = os.platform() === 'win32' ? 'powershell.exe' : 'bash';
class Dalai {
  constructor(url) {
    // If a url is given, request() proxies queries to a remote dalai server.
    if (url) this.url = url
    // Everything (the llama.cpp checkout and the models) lives under ~/dalai
    this.home = path.resolve(os.homedir(), "dalai")
    try {
      fs.mkdirSync(this.home, { recursive: true })
    } catch (e) { }
    // Base config for the pseudo-terminal used by exec()
    this.config = {
      name: 'xterm-color',
      cols: 80,
      rows: 30,
    }
  }
  async download(model) {
    // Number of consolidated.*.pth checkpoint shards per model size
    const num = {
      "7B": 1,
      "13B": 2,
      "30B": 4,
      "65B": 8,
    }
    const files = ["checklist.chk", "params.json"]
    for(let i=0; i<num[model]; i++) {
      files.push(`consolidated.0${i}.pth`)
    }
    const resolvedPath = path.resolve(this.home, "models", model)
    await fs.promises.mkdir(resolvedPath, { recursive: true }).catch((e) => { })

    // Download the per-model files into ~/dalai/models/<model>
    for(let file of files) {
      const task = `downloading ${file}`
      const downloader = new Downloader({
        url: `https://agi.gpt4.org/llama/LLaMA/${model}/${file}`,
        directory: resolvedPath,
        onProgress: (percentage, chunk, remainingSize) => {
          this.progress(task, percentage)
        },
      });
      try {
        this.startProgress(task)
        await downloader.download();
      } catch (error) {
        console.log(error);
      }
      this.progressBar.update(1);
      term("\n")
    }

    // The tokenizer files are shared across model sizes, so they go one
    // level up, into ~/dalai/models
    const files2 = ["tokenizer_checklist.chk", "tokenizer.model"]
    for(let file of files2) {
      const task = `downloading ${file}`
      const downloader = new Downloader({
        url: `https://agi.gpt4.org/llama/LLaMA/${file}`,
        directory: path.resolve(this.home, "models"),
        onProgress: (percentage, chunk, remainingSize) => {
          this.progress(task, percentage)
        },
      });
      try {
        this.startProgress(task)
        await downloader.download();
      } catch (error) {
        console.log(error);
      }
      this.progressBar.update(1);
      term("\n")
    }
  }
  async install(...models) {
    // install llama.cpp into ~/dalai, then fetch and convert the models.
    // Both pip3 and pip are tried, since only one may be on the PATH.
    await this.exec("pip3 install torch torchvision torchaudio sentencepiece numpy")
    await this.exec("pip install torch torchvision torchaudio sentencepiece numpy")
    await this.exec("git clone https://github.com/ggerganov/llama.cpp.git dalai", os.homedir())
    await this.exec("make", this.home)
    for(let model of models) {
      await this.download(model)
      // convert the .pth checkpoints to f16 ggml, then quantize to q4_0
      await this.exec(`python3 convert-pth-to-ggml.py models/${model}/ 1`, this.home)
      await this.quantize(model)
    }
  }
  serve(port) {
    // Standalone socket.io server: every incoming 'request' is run through
    // query(), and each generated chunk is broadcast back as a 'result'
    // to all connected clients.
    const httpServer = createServer();
    const io = new Server(httpServer)
    io.on("connection", (socket) => {
      socket.on('request', async (req) => {
        await this.query(req, (str) => {
          io.emit("result", { response: str, request: req })
        })
      });
    });
    httpServer.listen(port)
  }
  http(httpServer) {
    // Same as serve(), but attaches to an existing HTTP server
    const io = new Server(httpServer)
    io.on("connection", (socket) => {
      socket.on('request', async (req) => {
        await this.query(req, (str) => {
          io.emit("result", { response: str, request: req })
        })
      });
    });
  }
  async request(req, cb) {
    // With a url, forward the request to a remote dalai server;
    // otherwise run llama.cpp locally.
    if (this.url) {
      await this.connect(req, cb)
    } else {
      await this.query(req, cb)
    }
  }
  async query(req, cb) {
    // Build the llama.cpp ./main argument list from the request
    let o = {
      seed: req.seed || -1,
      threads: req.threads || 8,
      n_predict: req.n_predict || 128,
      model: `./models/${req.model || "7B"}/ggml-model-q4_0.bin`
    }
    if (req.top_k) o.top_k = req.top_k
    if (req.top_p) o.top_p = req.top_p
    if (req.temp) o.temp = req.temp
    if (req.batch_size) o.batch_size = req.batch_size

    let chunks = []
    for(let key in o) {
      chunks.push(`--${key} ${o[key]}`)
    }
    chunks.push(`-p "${req.prompt}"`)

    if (req.full) {
      // full mode: stream everything ./main prints, including llama.cpp's
      // own startup and timing logs
      await this.exec(`./main ${chunks.join(" ")}`, this.home, cb)
    } else {
      // otherwise, only stream the text between the "sampling parameters"
      // banner and the "mem per token" stats line.
      // NOTE: no /g flag here; a global regex keeps lastIndex state across
      // test() calls and would silently skip matches.
      const startpattern = /.*sampling parameters:.*/
      const endpattern = /.*mem per token.*/
      let started = false
      let ended = false
      await this.exec(`./main ${chunks.join(" ")}`, this.home, (msg) => {
        if (endpattern.test(msg)) ended = true
        if (started && !ended) {
          cb(msg)
        }
        if (startpattern.test(msg)) started = true
      })
    }
  }
  connect(req, cb) {
    const socket = io(this.url)
    socket.emit('request', req)
    // The server emits 'result' (see serve()); unwrap the payload so cb
    // receives the same response chunks as a local query()
    socket.on('result', ({ request, response }) => {
      cb(response)
    })
    socket.on('error', function(e) {
      throw e
    });
  }
  exec(cmd, cwd, cb) {
    return new Promise((resolve, reject) => {
      const config = Object.assign({}, this.config)
      if (cwd) {
        config.cwd = path.resolve(cwd)
      }
      // Run the command inside a pseudo-terminal so interactive output
      // (progress bars, prompts) behaves as it would in a real shell
      const ptyProcess = pty.spawn(shell, [], config)
      ptyProcess.onData((data) => {
        if (cb) {
          cb(data)
        } else {
          process.stdout.write(data);
        }
      });
      ptyProcess.onExit((res) => {
        resolve(res)
      });
      // Type the command, then 'exit' so the shell terminates and onExit
      // resolves the promise
      ptyProcess.write(`${cmd}\r`)
      ptyProcess.write("exit\r")
    })
  }
  async quantize(model) {
    // One ggml file per checkpoint shard; quantize each of them
    let num = {
      "7B": 1,
      "13B": 2,
      "30B": 4,
      "65B": 8,
    }
    for(let i=0; i<num[model]; i++) {
      const suffix = (i === 0 ? "" : `.${i}`)
      // the shard suffix has to go on the f16 input as well as the q4_0
      // output, since the converter writes one f16 file per shard
      await this.exec(`./quantize ./models/${model}/ggml-model-f16.bin${suffix} ./models/${model}/ggml-model-q4_0.bin${suffix} 2`, this.home)
    }
  }
  progress(task, percent) {
    this.progressBar.update(percent / 100);
  }
  startProgress(title) {
    this.progressBar = term.progressBar({
      width: 120,
      title,
      eta: true,
      percent: true
    });
  }
}

module.exports = Dalai
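/*
Usage sketch, not part of the original file: a minimal illustration of the
API above, assuming this module is required as `dalai` and run inside an
async context. Model names, port, and URL below are example values.

  const Dalai = require("dalai")

  // One-time setup: clones llama.cpp into ~/dalai, builds it, then
  // downloads, converts, and quantizes the 7B model
  await new Dalai().install("7B")

  // Run locally, streaming generated chunks to the callback
  await new Dalai().request({
    model: "7B",
    prompt: "The sky appears blue because",
    n_predict: 64,
  }, (token) => process.stdout.write(token))

  // Or serve over socket.io and query from another process
  new Dalai().serve(3000)
  await new Dalai("http://localhost:3000").request({ prompt: "hello" }, console.log)
*/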