diff --git a/src/browser/components/ChatInput/CreationControls.tsx b/src/browser/components/ChatInput/CreationControls.tsx index dd231843b..b10fbe27d 100644 --- a/src/browser/components/ChatInput/CreationControls.tsx +++ b/src/browser/components/ChatInput/CreationControls.tsx @@ -10,6 +10,8 @@ interface CreationControlsProps { runtimeMode: RuntimeMode; sshHost: string; onRuntimeChange: (mode: RuntimeMode, host: string) => void; + autoRebaseTrunk: boolean; + onAutoRebaseChange: (value: boolean) => void; disabled: boolean; } @@ -74,6 +76,27 @@ export function CreationControls(props: CreationControlsProps) { + + {/* Auto Rebase Toggle */} +
+ + + ? + + Fetches origin/{props.trunkBranch || "main"} and rebases the new workspace before + running any init hooks. Disable if you need to stay on the local trunk snapshot. + + +
); } diff --git a/src/browser/components/ChatInput/index.tsx b/src/browser/components/ChatInput/index.tsx index 20bbb1827..6f81143e9 100644 --- a/src/browser/components/ChatInput/index.tsx +++ b/src/browser/components/ChatInput/index.tsx @@ -1021,6 +1021,8 @@ export const ChatInput: React.FC = (props) => { runtimeMode={creationState.runtimeMode} sshHost={creationState.sshHost} onRuntimeChange={creationState.setRuntimeOptions} + autoRebaseTrunk={creationState.autoRebaseTrunk} + onAutoRebaseChange={creationState.setAutoRebaseTrunk} disabled={creationState.isSending || isSending} /> )} diff --git a/src/browser/components/ChatInput/useCreationWorkspace.ts b/src/browser/components/ChatInput/useCreationWorkspace.ts index 89c0ba4ac..021af168b 100644 --- a/src/browser/components/ChatInput/useCreationWorkspace.ts +++ b/src/browser/components/ChatInput/useCreationWorkspace.ts @@ -39,6 +39,8 @@ interface UseCreationWorkspaceReturn { runtimeMode: RuntimeMode; sshHost: string; setRuntimeOptions: (mode: RuntimeMode, host: string) => void; + autoRebaseTrunk: boolean; + setAutoRebaseTrunk: (value: boolean) => void; error: string | null; setError: (error: string | null) => void; isSending: boolean; @@ -60,6 +62,7 @@ export function useCreationWorkspace({ const [recommendedTrunk, setRecommendedTrunk] = useState(null); const [error, setError] = useState(null); const [isSending, setIsSending] = useState(false); + const [autoRebaseTrunk, setAutoRebaseTrunk] = useState(true); // Centralized draft workspace settings with automatic persistence const { settings, setRuntimeOptions, setTrunkBranch, getRuntimeString } = @@ -107,6 +110,7 @@ export function useCreationWorkspace({ runtimeConfig, projectPath, // Pass projectPath when workspaceId is null trunkBranch: settings.trunkBranch, // Pass selected trunk branch from settings + autoRebaseTrunk, }); if (!result.success) { @@ -139,6 +143,7 @@ export function useCreationWorkspace({ getRuntimeString, sendMessageOptions, settings.trunkBranch, + autoRebaseTrunk, ] ); @@ -149,6 +154,8 @@ export function useCreationWorkspace({ runtimeMode: settings.runtimeMode, sshHost: settings.sshHost, setRuntimeOptions, + autoRebaseTrunk, + setAutoRebaseTrunk, error, setError, isSending, diff --git a/src/common/types/ipc.ts b/src/common/types/ipc.ts index 339da510d..82d3ff163 100644 --- a/src/common/types/ipc.ts +++ b/src/common/types/ipc.ts @@ -217,6 +217,8 @@ export interface SendMessageOptions { maxOutputTokens?: number; providerOptions?: MuxProviderOptions; mode?: string; // Mode name - frontend narrows to specific values, backend accepts any string + /** Whether to rebase onto origin/ before first workspace init */ + autoRebaseTrunk?: boolean; muxMetadata?: MuxFrontendMetadata; // Frontend-defined metadata, backend treats as black-box } diff --git a/src/node/runtime/LocalRuntime.test.ts b/src/node/runtime/LocalRuntime.test.ts index a3cbdd053..638fe254e 100644 --- a/src/node/runtime/LocalRuntime.test.ts +++ b/src/node/runtime/LocalRuntime.test.ts @@ -2,6 +2,9 @@ import { describe, expect, it } from "bun:test"; import * as os from "os"; import * as path from "path"; import { LocalRuntime } from "./LocalRuntime"; +import * as fsPromises from "fs/promises"; +import { execFileSync } from "child_process"; +import type { InitLogger } from "./Runtime"; describe("LocalRuntime constructor", () => { it("should expand tilde in srcBaseDir", () => { @@ -65,3 +68,97 @@ describe("LocalRuntime.resolvePath", () => { expect(path.isAbsolute(resolved)).toBe(true); }); }); + +const GIT_ENV = { + ...process.env, + GIT_AUTHOR_NAME: "Test User", + GIT_AUTHOR_EMAIL: "test@example.com", + GIT_COMMITTER_NAME: "Test User", + GIT_COMMITTER_EMAIL: "test@example.com", +}; + +function runGit(args: string[], cwd?: string) { + execFileSync("git", args, { cwd, env: GIT_ENV }); +} + +function gitOutput(args: string[], cwd?: string): string { + return execFileSync("git", args, { cwd, env: GIT_ENV }).toString().trim(); +} + +function createTestInitLogger(): InitLogger { + const logs: string[] = []; + return { + logStep: (m: string) => { + logs.push(`[step] ${m}`); + }, + logStdout: (line: string) => { + if (line) logs.push(`[out] ${line}`); + }, + logStderr: (line: string) => { + if (line) logs.push(`[err] ${line}`); + }, + logComplete: (code: number) => { + logs.push(`[done] ${code}`); + }, + }; +} + +describe("LocalRuntime auto rebase", () => { + it("rebases onto origin when enabled", async () => { + const tmpRoot = await fsPromises.mkdtemp(path.join(os.tmpdir(), "local-runtime-")); + const originDir = path.join(tmpRoot, "origin.git"); + const projectDir = path.join(tmpRoot, "project"); + const upstreamDir = path.join(tmpRoot, "upstream"); + const workspacesDir = path.join(tmpRoot, "workspaces"); + const trunkBranch = "main"; + + try { + runGit(["init", "--bare", originDir]); + + await fsPromises.mkdir(projectDir, { recursive: true }); + runGit(["init", "-b", trunkBranch], projectDir); + runGit(["remote", "add", "origin", originDir], projectDir); + + await fsPromises.writeFile(path.join(projectDir, "README.md"), "first\n"); + runGit(["add", "README.md"], projectDir); + runGit(["commit", "-m", "initial"], projectDir); + runGit(["push", "-u", "origin", trunkBranch], projectDir); + + runGit(["clone", "-b", trunkBranch, originDir, upstreamDir]); + await fsPromises.appendFile(path.join(upstreamDir, "README.md"), "second\n"); + runGit(["commit", "-am", "upstream change"], upstreamDir); + runGit(["push", "origin", trunkBranch], upstreamDir); + + const runtime = new LocalRuntime(workspacesDir); + const initLogger = createTestInitLogger(); + const branchName = "auto-rebase-test"; + + const createResult = await runtime.createWorkspace({ + projectPath: projectDir, + branchName, + trunkBranch, + directoryName: branchName, + initLogger, + }); + + expect(createResult.success).toBe(true); + expect(createResult.workspacePath).toBeTruthy(); + const workspacePath = createResult.workspacePath!; + + await runtime.initWorkspace({ + projectPath: projectDir, + branchName, + trunkBranch, + workspacePath, + initLogger, + autoRebaseTrunk: true, + }); + + const workspaceHead = gitOutput(["rev-parse", "HEAD"], workspacePath); + const originHead = gitOutput(["rev-parse", `origin/${trunkBranch}`], projectDir); + expect(workspaceHead).toBe(originHead); + } finally { + await fsPromises.rm(tmpRoot, { recursive: true, force: true }); + } + }); +}); diff --git a/src/node/runtime/LocalRuntime.ts b/src/node/runtime/LocalRuntime.ts index 81012cd12..1bd5a5287 100644 --- a/src/node/runtime/LocalRuntime.ts +++ b/src/node/runtime/LocalRuntime.ts @@ -375,9 +375,13 @@ export class LocalRuntime implements Runtime { } async initWorkspace(params: WorkspaceInitParams): Promise { - const { projectPath, workspacePath, initLogger } = params; + const { projectPath, workspacePath, initLogger, trunkBranch, autoRebaseTrunk } = params; try { + if (autoRebaseTrunk && trunkBranch) { + await this.runAutoRebase(workspacePath, trunkBranch, initLogger); + } + // Run .mux/init hook if it exists // Note: runInitHook calls logComplete() internally if hook exists const hookExists = await checkInitHookExists(projectPath); @@ -399,6 +403,46 @@ export class LocalRuntime implements Runtime { } } + private async runAutoRebase( + workspacePath: string, + trunkBranch: string, + initLogger: InitLogger + ): Promise { + const quote = (value: string) => `"${value}"`; + + const hasOrigin = await (async () => { + try { + using remoteCheck = execAsync(`git -C ${quote(workspacePath)} remote get-url origin`); + const { stdout } = await remoteCheck.result; + return stdout.trim().length > 0; + } catch { + return false; + } + })(); + + if (!hasOrigin) { + initLogger.logStep("Skipping auto-rebase: origin remote not configured."); + return; + } + + initLogger.logStep(`Fetching origin/${trunkBranch}...`); + try { + using fetchProc = execAsync(`git -C ${quote(workspacePath)} fetch origin ${trunkBranch}`); + await fetchProc.result; + } catch (error) { + throw new Error(`Failed to fetch origin/${trunkBranch}: ${getErrorMessage(error)}`); + } + + initLogger.logStep(`Rebasing onto origin/${trunkBranch}...`); + try { + using rebaseProc = execAsync(`git -C ${quote(workspacePath)} rebase origin/${trunkBranch}`); + await rebaseProc.result; + initLogger.logStep(`Rebased onto origin/${trunkBranch}`); + } catch (error) { + throw new Error(`Failed to rebase onto origin/${trunkBranch}: ${getErrorMessage(error)}`); + } + } + /** * Run .mux/init hook if it exists and is executable */ diff --git a/src/node/runtime/Runtime.ts b/src/node/runtime/Runtime.ts index 4e01a0ceb..d1c9c7445 100644 --- a/src/node/runtime/Runtime.ts +++ b/src/node/runtime/Runtime.ts @@ -151,6 +151,8 @@ export interface WorkspaceInitParams { initLogger: InitLogger; /** Optional abort signal for cancellation */ abortSignal?: AbortSignal; + /** Automatically fetch/rebase onto origin/ before running init hooks */ + autoRebaseTrunk?: boolean; } /** diff --git a/src/node/runtime/SSHRuntime.ts b/src/node/runtime/SSHRuntime.ts index 22588e873..81e33318e 100644 --- a/src/node/runtime/SSHRuntime.ts +++ b/src/node/runtime/SSHRuntime.ts @@ -722,6 +722,72 @@ export class SSHRuntime implements Runtime { } } + private async runAutoRebase( + workspacePath: string, + trunkBranch: string, + initLogger: InitLogger, + abortSignal?: AbortSignal + ): Promise { + const remoteCheck = await this.exec(`git remote get-url origin`, { + cwd: workspacePath, + timeout: 30, + abortSignal, + }); + const [, _remoteStderr, remoteExitCode] = await Promise.all([ + streamToString(remoteCheck.stdout), + streamToString(remoteCheck.stderr), + remoteCheck.exitCode, + ]); + + if (remoteExitCode !== 0) { + initLogger.logStep("Skipping auto-rebase: origin remote not configured."); + return; + } + + await this.runGitCommandWithLogging( + workspacePath, + `git fetch origin ${trunkBranch}`, + `Fetching origin/${trunkBranch}...`, + initLogger, + abortSignal + ); + + await this.runGitCommandWithLogging( + workspacePath, + `git rebase origin/${trunkBranch}`, + `Rebasing onto origin/${trunkBranch}...`, + initLogger, + abortSignal + ); + + initLogger.logStep(`Rebased onto origin/${trunkBranch}`); + } + + private async runGitCommandWithLogging( + workspacePath: string, + command: string, + description: string, + initLogger: InitLogger, + abortSignal?: AbortSignal, + timeout = 300 + ): Promise { + initLogger.logStep(description); + const stream = await this.exec(command, { + cwd: workspacePath, + timeout, + abortSignal, + }); + const [stdout, stderr, exitCode] = await Promise.all([ + streamToString(stream.stdout), + streamToString(stream.stderr), + stream.exitCode, + ]); + + if (exitCode !== 0) { + throw new Error(stderr.trim() || stdout.trim() || description); + } + } + /** * Run .mux/init hook on remote machine if it exists */ @@ -906,6 +972,10 @@ export class SSHRuntime implements Runtime { } initLogger.logStep("Branch checked out successfully"); + if (params.autoRebaseTrunk && trunkBranch) { + await this.runAutoRebase(workspacePath, trunkBranch, initLogger, abortSignal); + } + // 3. Run .mux/init hook if it exists // Note: runInitHook calls logComplete() internally if hook exists const hookExists = await checkInitHookExists(projectPath); diff --git a/src/node/services/ipcMain.ts b/src/node/services/ipcMain.ts index d76819023..fade4a582 100644 --- a/src/node/services/ipcMain.ts +++ b/src/node/services/ipcMain.ts @@ -267,6 +267,7 @@ export class IpcMain { trunkBranch: recommendedTrunk, workspacePath: createResult.workspacePath, initLogger, + autoRebaseTrunk: options.autoRebaseTrunk ?? false, }) .catch((error: unknown) => { const errorMsg = error instanceof Error ? error.message : String(error); @@ -560,6 +561,7 @@ export class IpcMain { trunkBranch: normalizedTrunkBranch, workspacePath: createResult.workspacePath, initLogger, + autoRebaseTrunk: false, }) .catch((error: unknown) => { const errorMsg = error instanceof Error ? error.message : String(error);