Adding repeat penalty option

This commit is contained in:
Raven Scott 2023-04-03 02:28:20 +02:00
parent e4162bf03e
commit 72c683af7e
1 changed files with 25 additions and 2 deletions

View File

@ -42,8 +42,14 @@ module.exports = {
"description": "The higher the temperature, the more random the model output. A default 0.1 is used if not provided.",
"required": false,
"type": 3
},
{
"name": "repeat-penalty",
"description": "The higher the temperature, the more random the model output. A default 0.1 is used if not provided.",
"required": false,
"type": 3
}
],
],
@ -54,7 +60,8 @@ module.exports = {
let varsToCheck = [
{ name: "model", value: null },
{ name: "temperature", value: null },
{ name: "init-prompt", value: null }
{ name: "init-prompt", value: null },
{ name: "repeat-penalty", value: null }
];
for (let i = 0; i < options.length; i++) {
@ -71,6 +78,7 @@ module.exports = {
let userInputModel = varsToCheck.find(v => v.name === "model")?.value;
let userInputTemperature = varsToCheck.find(v => v.name === "temperature")?.value;
let userInputInitPrompt = varsToCheck.find(v => v.name === "init-prompt")?.value;
let userInputRepeatPenalty = varsToCheck.find(v => v.name === "repeat-penalty")?.value;
// Init Prompt Setting
if (userInputInitPrompt === null) {
@ -105,6 +113,21 @@ module.exports = {
}
}
// repeat setting
if (userInputRepeatPenalty === null) {
console.log("-- No RepeatPenalty provided, using default --")
} else {
const parsedRepeatPenalty = parseFloat(userInputRepeatPenalty);
if (parsedRepeatPenalty >= 0.1 && parsedRepeatPenalty <= 2) {
// temperature is within range
repeatPenalty = parsedRepeatPenalty;
} else {
// temperature is outside of range
return interaction.followUp(`Repeat Penalty must be between 0.1 and 2`);
}
}
var req = unirest('POST', apiUrl + '?model=' + model + '&temperature=' + temperature + '&top_k=' + topK + '&top_p=' + topP + '&max_length=' + maxLength + '&context_window=' + contextWindow + '&repeat_last_n=' + repeatLastN + '&repeat_penalty=' + repeatPenalty + '&init_prompt=' + initPrompt + '&n_threads=' + nThreads)
.headers({
'accept': 'application/json'