feat(server): improve session modify (#12928)

fix AI-248
This commit is contained in:
DarkSky 2025-06-25 20:02:21 +08:00 committed by GitHub
parent 697e0bf9ba
commit e32c9a814a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 282 additions and 270 deletions

View File

@ -16,7 +16,7 @@ Generated by [AVA](https://avajs.dev).
role: 'assistant',
},
],
tokens: 0,
tokens: 8,
},
]
@ -30,7 +30,7 @@ Generated by [AVA](https://avajs.dev).
role: 'assistant',
},
],
tokens: 0,
tokens: 8,
},
]

View File

@ -347,7 +347,7 @@ test('should be able to update chat session prompt', async t => {
t.truthy(sessionId, 'should create session');
// Update the session
const updatedSessionId = await session.updateSession({
const updatedSessionId = await session.update({
sessionId,
promptName: 'Search With AFFiNE AI',
userId,

View File

@ -1,6 +1,6 @@
import { randomUUID } from 'node:crypto';
import { AiSession, PrismaClient, User, Workspace } from '@prisma/client';
import { PrismaClient, User, Workspace } from '@prisma/client';
import ava, { TestFn } from 'ava';
import Sinon from 'sinon';
@ -43,7 +43,7 @@ test.before(async t => {
let user: User;
let workspace: Workspace;
let session: AiSession;
let sessionId: string;
let docId = 'doc1';
test.beforeEach(async t => {
@ -53,7 +53,7 @@ test.beforeEach(async t => {
email: 'test@affine.pro',
});
workspace = await t.context.workspace.create(user.id);
session = await t.context.copilotSession.create({
sessionId = await t.context.copilotSession.create({
sessionId: randomUUID(),
workspaceId: workspace.id,
docId,
@ -68,7 +68,7 @@ test.after(async t => {
});
test('should create a copilot context', async t => {
const { id: contextId } = await t.context.copilotContext.create(session.id);
const { id: contextId } = await t.context.copilotContext.create(sessionId);
t.truthy(contextId);
const context = await t.context.copilotContext.get(contextId);
@ -77,7 +77,7 @@ test('should create a copilot context', async t => {
const config = await t.context.copilotContext.getConfig(contextId);
t.is(config?.workspaceId, workspace.id, 'should get context config');
const context1 = await t.context.copilotContext.getBySessionId(session.id);
const context1 = await t.context.copilotContext.getBySessionId(sessionId);
t.is(context1?.id, contextId, 'should get context by session id');
});
@ -87,7 +87,7 @@ test('should get null for non-exist job', async t => {
});
test('should update context', async t => {
const { id: contextId } = await t.context.copilotContext.create(session.id);
const { id: contextId } = await t.context.copilotContext.create(sessionId);
const config = await t.context.copilotContext.getConfig(contextId);
const doc = {
@ -102,7 +102,7 @@ test('should update context', async t => {
});
test('should insert embedding by doc id', async t => {
const { id: contextId } = await t.context.copilotContext.create(session.id);
const { id: contextId } = await t.context.copilotContext.create(sessionId);
{
await t.context.copilotContext.insertFileEmbedding(contextId, 'file-id', [

View File

@ -6,7 +6,7 @@ import ava, { ExecutionContext, TestFn } from 'ava';
import { CopilotPromptInvalid, CopilotSessionInvalidInput } from '../../base';
import {
CopilotSessionModel,
UpdateChatSessionData,
UpdateChatSessionOptions,
UserModel,
WorkspaceModel,
} from '../../models';
@ -174,7 +174,10 @@ test('should check session validation for prompts', async t => {
sessionTypes.forEach(({ name, session }) => {
t.notThrows(
() =>
copilotSession.checkSessionPrompt(session, 'test-prompt', undefined),
copilotSession.checkSessionPrompt(session, {
name: 'test-prompt',
action: undefined,
}),
`${name} session should allow non-action prompts`
);
});
@ -195,14 +198,20 @@ test('should check session validation for prompts', async t => {
if (shouldThrow) {
t.throws(
() =>
copilotSession.checkSessionPrompt(session, 'action-prompt', 'edit'),
copilotSession.checkSessionPrompt(session, {
name: 'action-prompt',
action: 'edit',
}),
{ instanceOf: CopilotPromptInvalid },
`${name} session should reject action prompts`
);
} else {
t.notThrows(
() =>
copilotSession.checkSessionPrompt(session, 'action-prompt', 'edit'),
copilotSession.checkSessionPrompt(session, {
name: 'action-prompt',
action: 'edit',
}),
`${name} session should allow action prompts`
);
}
@ -323,14 +332,19 @@ test('should handle session updates and validations', async t => {
},
});
type UpdateData = Omit<UpdateChatSessionOptions, 'userId' | 'sessionId'>;
const assertUpdateThrows = async (
t: ExecutionContext<Context>,
sessionId: string,
updateData: UpdateChatSessionData,
updateData: UpdateData,
message: string
) => {
await t.throwsAsync(
t.context.copilotSession.update(user.id, sessionId, updateData),
t.context.copilotSession.update({
...updateData,
userId: user.id,
sessionId,
}),
{ instanceOf: CopilotSessionInvalidInput },
message
);
@ -339,11 +353,15 @@ test('should handle session updates and validations', async t => {
const assertUpdate = async (
t: ExecutionContext<Context>,
sessionId: string,
updateData: UpdateChatSessionData,
updateData: UpdateData,
message: string
) => {
await t.notThrowsAsync(
t.context.copilotSession.update(user.id, sessionId, updateData),
t.context.copilotSession.update({
...updateData,
userId: user.id,
sessionId,
}),
message
);
};
@ -386,7 +404,6 @@ test('should handle session updates and validations', async t => {
'forked session should reject docId update'
);
}
{
// case 3: prompt update validation
await assertUpdate(
@ -415,14 +432,13 @@ test('should handle session updates and validations', async t => {
await createTestSession(t, { sessionId: existingPinnedId, pinned: true });
// should unpin existing when pinning new session
await copilotSession.update(user.id, sessionId, { pinned: true });
await copilotSession.update({ userId: user.id, sessionId, pinned: true });
const sessionStatesAfterPin = await Promise.all([
getSessionState(db, sessionId),
getSessionState(db, existingPinnedId),
]);
t.snapshot(
sessionStatesAfterPin,
[
await getSessionState(db, sessionId),
await getSessionState(db, existingPinnedId),
],
'should unpin existing when pinning new session'
);
}
@ -430,11 +446,8 @@ test('should handle session updates and validations', async t => {
// test type conversions
{
const conversionSteps: any[] = [];
const convertSession = async (
step: string,
data: UpdateChatSessionData
) => {
await copilotSession.update(user.id, sessionId, data);
const convertSession = async (step: string, data: UpdateData) => {
await copilotSession.update({ ...data, userId: user.id, sessionId });
const session = await db.aiSession.findUnique({
where: { id: sessionId },
select: { docId: true, pinned: true },

View File

@ -9,6 +9,7 @@ import {
CopilotSessionInvalidInput,
CopilotSessionNotFound,
} from '../base';
import { getTokenEncoder } from '../native';
import { BaseModel } from './base';
export enum SessionType {
@ -17,6 +18,12 @@ export enum SessionType {
Doc = 'doc', // docId points to specific document
}
type ChatPrompt = {
name: string;
action?: string | null;
model: string;
};
type ChatAttachment = { attachment: string; mimeType: string } | string;
type ChatStreamObject = {
@ -38,7 +45,7 @@ type ChatMessage = {
createdAt: Date;
};
type ChatSession = {
type PureChatSession = {
sessionId: string;
workspaceId: string;
docId?: string | null;
@ -46,22 +53,44 @@ type ChatSession = {
messages?: ChatMessage[];
// connect ids
userId: string;
promptName: string;
promptAction: string | null;
parentSessionId?: string | null;
};
export type UpdateChatSessionData = Partial<
Pick<ChatSession, 'docId' | 'pinned' | 'promptName'>
>;
export type UpdateChatSession = Pick<ChatSession, 'userId' | 'sessionId'> &
UpdateChatSessionData;
type ChatSession = PureChatSession & {
// connect ids
promptName: string;
promptAction: string | null;
};
export type ListSessionOptions = {
type ChatSessionWithPrompt = PureChatSession & {
prompt: ChatPrompt;
};
type ChatSessionBaseState = Pick<ChatSession, 'userId' | 'sessionId'>;
export type ForkSessionOptions = Omit<
ChatSession,
'messages' | 'promptName' | 'promptAction'
> & {
prompt: { name: string; action: string | null | undefined; model: string };
messages: ChatMessage[];
};
type UpdateChatSessionMessage = ChatSessionBaseState & {
prompt: { model: string };
messages: ChatMessage[];
};
export type UpdateChatSessionOptions = ChatSessionBaseState &
Pick<Partial<ChatSession>, 'docId' | 'pinned' | 'promptName'>;
export type UpdateChatSession = ChatSessionBaseState & UpdateChatSessionOptions;
export type ListSessionOptions = Pick<
Partial<ChatSession>,
'sessionId' | 'workspaceId' | 'docId' | 'pinned'
> & {
userId: string;
sessionId?: string;
workspaceId?: string;
docId?: string;
action?: boolean;
fork?: boolean;
limit?: number;
@ -74,6 +103,13 @@ export type ListSessionOptions = {
withMessages?: boolean;
};
export type CleanupSessionOptions = Pick<
ChatSession,
'userId' | 'workspaceId' | 'docId'
> & {
sessionIds: string[];
};
@Injectable()
export class CopilotSessionModel extends BaseModel {
getSessionType(session: Pick<ChatSession, 'docId' | 'pinned'>): SessionType {
@ -84,10 +120,10 @@ export class CopilotSessionModel extends BaseModel {
checkSessionPrompt(
session: Pick<ChatSession, 'docId' | 'pinned'>,
promptName: string,
promptAction: string | undefined
prompt: Partial<ChatPrompt>
): boolean {
const sessionType = this.getSessionType(session);
const { name: promptName, action: promptAction } = prompt;
// workspace and pinned sessions cannot use action prompts
if (
@ -110,12 +146,18 @@ export class CopilotSessionModel extends BaseModel {
}
@Transactional()
async create(state: ChatSession) {
async create(state: ChatSession, reuseChat = false): Promise<string> {
// find and return existing session if session is chat session
if (reuseChat && !state.promptAction) {
const sessionId = await this.find(state);
if (sessionId) return sessionId;
}
if (state.pinned) {
await this.unpin(state.workspaceId, state.userId);
}
const row = await this.db.aiSession.create({
const session = await this.db.aiSession.create({
data: {
id: state.sessionId,
workspaceId: state.workspaceId,
@ -127,8 +169,46 @@ export class CopilotSessionModel extends BaseModel {
promptAction: state.promptAction,
parentSessionId: state.parentSessionId,
},
select: { id: true },
});
return row;
return session.id;
}
@Transactional()
async createWithPrompt(
state: ChatSessionWithPrompt,
reuseChat = false
): Promise<string> {
const { prompt, ...rest } = state;
return await this.models.copilotSession.create(
{ ...rest, promptName: prompt.name, promptAction: prompt.action ?? null },
reuseChat
);
}
@Transactional()
async fork(options: ForkSessionOptions): Promise<string> {
if (!options.messages?.length) {
throw new CopilotSessionInvalidInput(
'Cannot fork session without messages'
);
}
if (options.pinned) {
await this.unpin(options.workspaceId, options.userId);
}
const { messages, ...forkedState } = options;
// create session
const sessionId = await this.createWithPrompt({
...forkedState,
messages: [],
});
// save message
await this.models.copilotSession.updateMessages({
...forkedState,
messages,
});
return sessionId;
}
@Transactional()
@ -143,9 +223,7 @@ export class CopilotSessionModel extends BaseModel {
}
@Transactional()
async getChatSessionId(
state: Omit<ChatSession, 'promptName' | 'promptAction'>
) {
async find(state: PureChatSession) {
const extraCondition: Record<string, any> = {};
if (state.parentSessionId) {
// also check session id if provided session is forked session
@ -287,11 +365,8 @@ export class CopilotSessionModel extends BaseModel {
}
@Transactional()
async update(
userId: string,
sessionId: string,
data: UpdateChatSessionData
): Promise<string> {
async update(options: UpdateChatSessionOptions): Promise<string> {
const { userId, sessionId, docId, promptName, pinned } = options;
const session = await this.getExists(
sessionId,
{
@ -313,33 +388,71 @@ export class CopilotSessionModel extends BaseModel {
throw new CopilotSessionInvalidInput(
`Cannot update action: ${session.id}`
);
} else if (data.docId && session.parentSessionId) {
} else if (docId && session.parentSessionId) {
throw new CopilotSessionInvalidInput(
`Cannot update docId for forked session: ${session.id}`
);
}
if (data.promptName) {
if (promptName) {
const prompt = await this.db.aiPrompt.findFirst({
where: { name: data.promptName },
where: { name: promptName },
});
// always not allow to update to action prompt
if (!prompt || prompt.action) {
throw new CopilotSessionInvalidInput(
`Prompt ${data.promptName} not found or not available for session ${sessionId}`
`Prompt ${promptName} not found or not available for session ${sessionId}`
);
}
}
if (data.pinned && data.pinned !== session.pinned) {
if (pinned && pinned !== session.pinned) {
// if pin the session, unpin exists session in the workspace
await this.unpin(session.workspaceId, userId);
}
await this.db.aiSession.update({ where: { id: sessionId }, data });
await this.db.aiSession.update({
where: { id: sessionId },
data: { docId, promptName, pinned },
});
return sessionId;
}
@Transactional()
async cleanup(options: CleanupSessionOptions): Promise<string[]> {
const sessions = await this.db.aiSession.findMany({
where: {
id: { in: options.sessionIds },
userId: options.userId,
workspaceId: options.workspaceId,
docId: options.docId,
deletedAt: null,
},
select: { id: true, prompt: true },
});
const sessionIds = sessions.map(({ id }) => id);
// cleanup all messages
await this.db.aiSessionMessage.deleteMany({
where: { sessionId: { in: sessionIds } },
});
// only mark action session as deleted
// chat session always can be reuse
const actionIds = sessions
.filter(({ prompt }) => !!prompt.action)
.map(({ id }) => id);
// 标记 action session 为已删除
if (actionIds.length > 0) {
await this.db.aiSession.updateMany({
where: { id: { in: actionIds } },
data: { pinned: false, deletedAt: new Date() },
});
}
return sessionIds;
}
@Transactional()
async getMessages(
sessionId: string,
@ -353,31 +466,42 @@ export class CopilotSessionModel extends BaseModel {
});
}
@Transactional()
async setMessages(
sessionId: string,
messages: ChatMessage[],
tokenCost: number
) {
await this.db.aiSessionMessage.createMany({
data: messages.map(m => ({
...m,
attachments: m.attachments || undefined,
params: omit(m.params, ['docs']) || undefined,
streamObjects: m.streamObjects || undefined,
sessionId,
})),
});
private calculateTokenSize(messages: any[], model: string): number {
const encoder = getTokenEncoder(model);
const content = messages.map(m => m.content).join('');
return encoder?.count(content) || 0;
}
// only count message generated by user
const userMessages = messages.filter(m => m.role === 'user');
await this.db.aiSession.update({
where: { id: sessionId },
data: {
messageCost: { increment: userMessages.length },
tokenCost: { increment: tokenCost },
},
});
@Transactional()
async updateMessages(state: UpdateChatSessionMessage) {
const { sessionId, userId, messages } = state;
const haveSession = await this.has(sessionId, userId);
if (!haveSession) {
throw new CopilotSessionNotFound();
}
if (messages.length) {
const tokenCost = this.calculateTokenSize(messages, state.prompt.model);
await this.db.aiSessionMessage.createMany({
data: messages.map(m => ({
...m,
attachments: m.attachments || undefined,
params: omit(m.params, ['docs']) || undefined,
streamObjects: m.streamObjects || undefined,
sessionId,
})),
});
// only count message generated by user
const userMessages = messages.filter(m => m.role === 'user');
await this.db.aiSession.update({
where: { id: sessionId },
data: {
messageCost: { increment: userMessages.length },
tokenCost: { increment: tokenCost },
},
});
}
}
@Transactional()
@ -404,4 +528,15 @@ export class CopilotSessionModel extends BaseModel {
await this.db.aiSessionMessage.deleteMany({ where: { id: { in: ids } } });
}
}
@Transactional()
async countUserMessages(userId: string): Promise<number> {
const sessions = await this.db.aiSession.findMany({
where: { userId },
select: { messageCost: true, prompt: { select: { action: true } } },
});
return sessions
.map(({ messageCost, prompt: { action } }) => (action ? 1 : messageCost))
.reduce((prev, cost) => prev + cost, 0);
}
}

View File

@ -1,4 +1,4 @@
import serverNativeModule from '@affine/server-native';
import serverNativeModule, { type Tokenizer } from '@affine/server-native';
export const mergeUpdatesInApplyWay = serverNativeModule.mergeUpdatesInApplyWay;
@ -16,10 +16,21 @@ export const mintChallengeResponse = async (resource: string, bits: number) => {
return serverNativeModule.mintChallengeResponse(resource, bits);
};
export function getTokenEncoder(model?: string | null): Tokenizer | null {
if (!model) return null;
if (model.startsWith('gpt')) {
return serverNativeModule.fromModelName(model);
} else if (model.startsWith('dall')) {
// dalle don't need to calc the token
return null;
} else {
// c100k based model
return serverNativeModule.fromModelName('gpt-4');
}
}
export const getMime = serverNativeModule.getMime;
export const parseDoc = serverNativeModule.parseDoc;
export const Tokenizer = serverNativeModule.Tokenizer;
export const fromModelName = serverNativeModule.fromModelName;
export const htmlSanitize = serverNativeModule.htmlSanitize;
export const AFFINE_PRO_PUBLIC_KEY = serverNativeModule.AFFINE_PRO_PUBLIC_KEY;
export const AFFINE_PRO_LICENSE_AES_KEY =

View File

@ -3,8 +3,8 @@ import { Logger } from '@nestjs/common';
import { AiPrompt } from '@prisma/client';
import Mustache from 'mustache';
import { getTokenEncoder } from '../../../native';
import { PromptConfig, PromptMessage, PromptParams } from '../providers';
import { getTokenEncoder } from '../types';
// disable escaping
Mustache.escape = (text: string) => text;
@ -56,8 +56,7 @@ export class ChatPrompt {
private readonly messages: PromptMessage[]
) {
this.encoder = getTokenEncoder(model);
this.promptTokenSize =
this.encoder?.count(messages.map(m => m.content).join('') || '') || 0;
this.promptTokenSize = this.encode(messages.map(m => m.content).join(''));
this.templateParamKeys = extractMustacheParams(
messages.map(m => m.content).join('')
);

View File

@ -39,15 +39,12 @@ import { PromptMessage, StreamObject } from './providers';
import { ChatSessionService } from './session';
import { CopilotStorage } from './storage';
import {
AvailableModels,
type ChatHistory,
type ChatMessage,
type ChatSessionState,
SubmittedMessage,
} from './types';
registerEnumType(AvailableModels, { name: 'CopilotModel' });
export const COPILOT_LOCKER = 'copilot';
// ================== Input Types ==================
@ -301,8 +298,6 @@ class CopilotPromptMessageType {
params!: Record<string, string> | null;
}
registerEnumType(AvailableModels, { name: 'CopilotModels' });
@ObjectType()
class CopilotPromptType {
@Field(() => String)
@ -533,7 +528,7 @@ export class CopilotResolver {
}
await this.chatSession.checkQuota(user.id);
return await this.chatSession.updateSession({
return await this.chatSession.update({
...options,
userId: user.id,
});
@ -682,8 +677,8 @@ class CreateCopilotPromptInput {
@Field(() => String)
name!: string;
@Field(() => AvailableModels)
model!: AvailableModels;
@Field(() => String)
model!: string;
@Field(() => String, { nullable: true })
action!: string | null;

View File

@ -2,7 +2,7 @@ import { randomUUID } from 'node:crypto';
import { Injectable, Logger } from '@nestjs/common';
import { Transactional } from '@nestjs-cls/transactional';
import { AiPromptRole, PrismaClient } from '@prisma/client';
import { AiPromptRole } from '@prisma/client';
import {
CopilotActionTaken,
@ -14,10 +14,11 @@ import {
} from '../../base';
import { QuotaService } from '../../core/quota';
import {
CleanupSessionOptions,
ListSessionOptions,
Models,
type UpdateChatSession,
UpdateChatSessionData,
UpdateChatSessionOptions,
} from '../../models';
import { ChatMessageCache } from './message';
import { PromptService } from './prompt';
@ -29,7 +30,6 @@ import {
type ChatSessionForkOptions,
type ChatSessionOptions,
type ChatSessionState,
getTokenEncoder,
type SubmittedMessage,
} from './types';
@ -224,46 +224,12 @@ export class ChatSessionService {
private readonly logger = new Logger(ChatSessionService.name);
constructor(
private readonly db: PrismaClient,
private readonly quota: QuotaService,
private readonly messageCache: ChatMessageCache,
private readonly prompt: PromptService,
private readonly models: Models
) {}
@Transactional()
private async setSession(state: ChatSessionState): Promise<string> {
const session = this.models.copilotSession;
let sessionId = state.sessionId;
// find existing session if session is chat session
if (!state.prompt.action) {
const id = await session.getChatSessionId(state);
if (id) sessionId = id;
}
const haveSession = await session.has(sessionId, state.userId);
if (haveSession) {
// message will only exists when setSession call by session.save
if (state.messages.length) {
await session.setMessages(
sessionId,
state.messages,
this.calculateTokenSize(state.messages, state.prompt.model)
);
}
} else {
await session.create({
...state,
sessionId,
promptName: state.prompt.name,
promptAction: state.prompt.action ?? null,
});
}
return sessionId;
}
async getSession(sessionId: string): Promise<ChatSessionState | undefined> {
const session = await this.models.copilotSession.get(sessionId);
if (!session) return;
@ -296,23 +262,6 @@ export class ChatSessionService {
);
}
private calculateTokenSize(messages: PromptMessage[], model: string): number {
const encoder = getTokenEncoder(model);
return messages
.map(m => encoder?.count(m.content) ?? 0)
.reduce((total, length) => total + length, 0);
}
private async countUserMessages(userId: string): Promise<number> {
const sessions = await this.db.aiSession.findMany({
where: { userId },
select: { messageCost: true, prompt: { select: { action: true } } },
});
return sessions
.map(({ messageCost, prompt: { action } }) => (action ? 1 : messageCost))
.reduce((prev, cost) => prev + cost, 0);
}
async listSessions(
options: ListSessionOptions
): Promise<Omit<ChatSessionState, 'messages'>[]> {
@ -431,7 +380,7 @@ export class ChatSessionService {
limit = quota.copilotActionLimit;
}
const used = await this.countUserMessages(userId);
const used = await this.models.copilotSession.countUserMessages(userId);
return { limit, used };
}
@ -456,20 +405,19 @@ export class ChatSessionService {
}
// validate prompt compatibility with session type
this.models.copilotSession.checkSessionPrompt(
options,
prompt.name,
prompt.action
);
this.models.copilotSession.checkSessionPrompt(options, prompt);
return await this.setSession({
...options,
sessionId,
prompt,
messages: [],
// when client create chat session, we always find root session
parentSessionId: null,
});
return await this.models.copilotSession.createWithPrompt(
{
...options,
sessionId,
prompt,
messages: [],
// when client create chat session, we always find root session
parentSessionId: null,
},
true
);
}
@Transactional()
@ -478,13 +426,16 @@ export class ChatSessionService {
}
@Transactional()
async updateSession(options: UpdateChatSession): Promise<string> {
async update(options: UpdateChatSession): Promise<string> {
const session = await this.getSession(options.sessionId);
if (!session) {
throw new CopilotSessionNotFound();
}
const finalData: UpdateChatSessionData = {};
const finalData: UpdateChatSessionOptions = {
userId: options.userId,
sessionId: options.sessionId,
};
if (options.promptName) {
const prompt = await this.prompt.get(options.promptName);
if (!prompt) {
@ -492,11 +443,7 @@ export class ChatSessionService {
throw new CopilotPromptNotFound({ name: options.promptName });
}
this.models.copilotSession.checkSessionPrompt(
session,
prompt.name,
prompt.action
);
this.models.copilotSession.checkSessionPrompt(session, prompt);
finalData.promptName = prompt.name;
}
finalData.pinned = options.pinned;
@ -508,21 +455,15 @@ export class ChatSessionService {
);
}
return await this.models.copilotSession.update(
options.userId,
options.sessionId,
finalData
);
return await this.models.copilotSession.update(finalData);
}
@Transactional()
async fork(options: ChatSessionForkOptions): Promise<string> {
const state = await this.getSession(options.sessionId);
if (!state) {
throw new CopilotSessionNotFound();
}
if (state.pinned) {
await this.unpin(options.workspaceId, options.userId);
}
let messages = state.messages.map(m => ({ ...m, id: undefined }));
if (options.latestMessageId) {
@ -538,62 +479,17 @@ export class ChatSessionService {
messages = messages.slice(0, lastMessageIdx + 1);
}
const forkedState = {
return await this.models.copilotSession.fork({
...state,
userId: options.userId,
sessionId: randomUUID(),
messages: [],
parentSessionId: options.sessionId,
};
// create session
await this.setSession(forkedState);
// save message
return await this.setSession({ ...forkedState, messages });
messages,
});
}
async cleanup(
options: Omit<ChatSessionOptions, 'pinned' | 'promptName'> & {
sessionIds: string[];
}
) {
return await this.db.$transaction(async tx => {
const sessions = await tx.aiSession.findMany({
where: {
id: { in: options.sessionIds },
userId: options.userId,
workspaceId: options.workspaceId,
docId: options.docId,
deletedAt: null,
},
select: { id: true, promptName: true },
});
const sessionIds = sessions.map(({ id }) => id);
// cleanup all messages
await tx.aiSessionMessage.deleteMany({
where: { sessionId: { in: sessionIds } },
});
// only mark action session as deleted
// chat session always can be reuse
const actionIds = (
await Promise.all(
sessions.map(({ id, promptName }) =>
this.prompt
.get(promptName)
.then(prompt => ({ id, action: !!prompt?.action }))
)
)
)
.filter(({ action }) => action)
.map(({ id }) => id);
await tx.aiSession.updateMany({
where: { id: { in: actionIds } },
data: { pinned: false, deletedAt: new Date() },
});
return [...sessionIds, ...actionIds];
});
async cleanup(options: CleanupSessionOptions) {
return await this.models.copilotSession.cleanup(options);
}
async createMessage(message: SubmittedMessage): Promise<string> {
@ -617,7 +513,7 @@ export class ChatSessionService {
const state = await this.getSession(sessionId);
if (state) {
return new ChatSession(this.messageCache, state, async state => {
await this.setSession(state);
await this.models.copilotSession.updateMessages(state);
});
}
return null;

View File

@ -1,8 +1,6 @@
import { type Tokenizer } from '@affine/server-native';
import { z } from 'zod';
import { OneMB } from '../../base';
import { fromModelName } from '../../native';
import type { ChatPrompt } from './prompt';
import { PromptMessageSchema, PureMessageSchema } from './providers';
@ -38,41 +36,6 @@ export const ChatQuerySchema = z
})
);
export enum AvailableModels {
// text to text
Gpt4Omni = 'gpt-4o',
Gpt4Omni0806 = 'gpt-4o-2024-08-06',
Gpt4OmniMini = 'gpt-4o-mini',
Gpt4OmniMini0718 = 'gpt-4o-mini-2024-07-18',
Gpt41 = 'gpt-4.1',
Gpt410414 = 'gpt-4.1-2025-04-14',
Gpt41Mini = 'gpt-4.1-mini',
Gpt41Nano = 'gpt-4.1-nano',
// embeddings
TextEmbedding3Large = 'text-embedding-3-large',
TextEmbedding3Small = 'text-embedding-3-small',
TextEmbeddingAda002 = 'text-embedding-ada-002',
// text to image
DallE3 = 'dall-e-3',
GptImage = 'gpt-image-1',
}
const availableModels = Object.values(AvailableModels);
export function getTokenEncoder(model?: string | null): Tokenizer | null {
if (!model) return null;
if (!availableModels.includes(model as AvailableModels)) return null;
if (model.startsWith('gpt')) {
return fromModelName(model);
} else if (model.startsWith('dall')) {
// dalle don't need to calc the token
return null;
} else {
// c100k based model
return fromModelName('gpt-4');
}
}
// ======== ChatMessage ========
export const ChatMessageSchema = PromptMessageSchema.extend({