From fd34dfff67cde0d1cb4ee7627772654d9a03a8df Mon Sep 17 00:00:00 2001 From: ribawaja Date: Sun, 1 Jan 2023 20:30:48 -0500 Subject: [PATCH] add eager generate feature --- js/ui/tool/dream.js | 82 +++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 80 insertions(+), 2 deletions(-) diff --git a/js/ui/tool/dream.js b/js/ui/tool/dream.js index adb6224..c345e2a 100644 --- a/js/ui/tool/dream.js +++ b/js/ui/tool/dream.js @@ -3,6 +3,7 @@ let generationQueue = []; let generationAreas = new Set(); let generating = false; + /** * Starts progress monitoring bar * @@ -375,12 +376,16 @@ const _generate = async (endpoint, request, bb, options = {}) => { }); }; + const sendInterrupt = () => { + fetch(`${host}${config.api.path}interrupt`, {method: "POST"}); + } + // Add Interrupt Button const interruptButton = makeElement("button", bb.x + bb.w - 100, bb.y + bb.h); interruptButton.classList.add("dream-stop-btn"); interruptButton.textContent = "Interrupt"; interruptButton.addEventListener("click", () => { - fetch(`${host}${config.api.path}interrupt`, {method: "POST"}); + sendInterrupt(); interruptButton.disabled = true; }); const marchingOptions = {}; @@ -390,6 +395,10 @@ const _generate = async (endpoint, request, bb, options = {}) => { console.info(`[dream] Generating images for prompt '${request.prompt}'`); console.debug(request); + + eagerGenerateCount = toolbar._current_tool.state.eagerGenerateCount; + isDreamComplete = false; + let stopProgress = null; try { let stopDrawingStatus = false; @@ -428,6 +437,17 @@ const _generate = async (endpoint, request, bb, options = {}) => { imageCollection.inputElement.removeChild(interruptButton); } + const needMoreGenerations = () => { + return (eagerGenerateCount > 0) && + (images.length - highestNavigatedImageIndex <= eagerGenerateCount); + } + + const isGenerationPending = () => { + return generationQueue.length > 0; + } + + let highestNavigatedImageIndex = 0; + // Image navigation const prevImg = () => { at--; @@ -443,10 +463,16 @@ const _generate = async (endpoint, request, bb, options = {}) => { at++; if (at >= images.length) at = 0; + highestNavigatedImageIndex = Math.max(at, highestNavigatedImageIndex); + imageindextxt.textContent = `${at}/${images.length - 1}`; var seed = seeds[at]; seedbtn.title = "Use seed " + seed; redraw(); + + if (needMoreGenerations() && !isGenerationPending()) { + makeMore(); + } }; const applyImg = async () => { @@ -504,6 +530,12 @@ const _generate = async (endpoint, request, bb, options = {}) => { } nextQueue(moreQ); + + //Start the next batch if we're eager-generating + if (needMoreGenerations() && !isGenerationPending() && !isDreamComplete) { + makeMore(); + } + }; const discardImg = async () => { @@ -657,6 +689,10 @@ const _generate = async (endpoint, request, bb, options = {}) => { mouse.listen.world.btn.right.onclick.clear(oncancelhandler); mouse.listen.world.btn.middle.onclick.clear(onmorehandler); mouse.listen.world.onwheel.clear(onwheelhandler); + isDreamComplete = true; + if (generating) { + sendInterrupt(); + } }; redraw(); @@ -740,6 +776,12 @@ const _generate = async (endpoint, request, bb, options = {}) => { imageSelectMenu.appendChild(seedbtn); nextQueue(initialQ); + + //Start the next batch after the initial generation + if (needMoreGenerations()) { + makeMore(); + } + }; /** @@ -1186,6 +1228,7 @@ const dreamTool = () => state.keepUnmaskedBlur = 8; state.overMaskPx = 20; state.preserveMasks = false; + state.eagerGenerateCount = 0; state.erasePrevCursor = () => uiCtx.clearRect(0, 0, uiCanvas.width, uiCanvas.height); @@ -1465,6 +1508,7 @@ const dreamTool = () => "Preserve Brushed Masks" ).label; + // Overmasking Slider state.ctxmenu.overMaskPxLabel = _toolbar_input.slider( state, @@ -1477,6 +1521,22 @@ const dreamTool = () => textStep: 1, } ).slider; + + + // Eager generation Slider + state.ctxmenu.eagerGenerateCountLabel = _toolbar_input.slider( + state, + "eagerGenerateCount", + "Generate-ahead count", + { + min: 0, + max: 100, + step: 2, + textStep: 1, + } + ).slider; + + } menu.appendChild(state.ctxmenu.cursorSizeSlider); @@ -1489,6 +1549,8 @@ const dreamTool = () => menu.appendChild(state.ctxmenu.preserveMasksLabel); menu.appendChild(document.createElement("br")); menu.appendChild(state.ctxmenu.overMaskPxLabel); + menu.appendChild(document.createElement("br")); + menu.appendChild(state.ctxmenu.eagerGenerateCountLabel); }, shortcut: "D", } @@ -1573,6 +1635,7 @@ const img2imgTool = () => state.keepUnmaskedBlur = 8; state.fullResolution = false; state.preserveMasks = false; + state.eagerGenerateCount = 0; state.denoisingStrength = 0.7; @@ -2006,6 +2069,19 @@ const img2imgTool = () => textStep: 1, } ).slider; + + // Eager generation Slider + state.ctxmenu.eagerGenerateCountLabel = _toolbar_input.slider( + state, + "eagerGenerateCount", + "Generate-ahead count", + { + min: 0, + max: 100, + step: 2, + textStep: 1, + } + ).slider; } menu.appendChild(state.ctxmenu.cursorSizeSlider); @@ -2023,6 +2099,8 @@ const img2imgTool = () => menu.appendChild(state.ctxmenu.denoisingStrengthSlider); menu.appendChild(state.ctxmenu.borderMaskGradientCheckbox); menu.appendChild(state.ctxmenu.borderMaskSlider); + menu.appendChild(document.createElement("br")); + menu.appendChild(state.ctxmenu.eagerGenerateCountLabel); }, shortcut: "I", } @@ -2031,7 +2109,7 @@ const img2imgTool = () => window.onbeforeunload = async () => { // Stop current generation on page close if (generating) - await fetch(`${host}${config.api.path}interrupt`, {method: "POST"}); + await sendInterrupt(); }; const sendSeed = (seed) => {