parent
697e0bf9ba
commit
e32c9a814a
@ -16,7 +16,7 @@ Generated by [AVA](https://avajs.dev).
|
||||
role: 'assistant',
|
||||
},
|
||||
],
|
||||
tokens: 0,
|
||||
tokens: 8,
|
||||
},
|
||||
]
|
||||
|
||||
@ -30,7 +30,7 @@ Generated by [AVA](https://avajs.dev).
|
||||
role: 'assistant',
|
||||
},
|
||||
],
|
||||
tokens: 0,
|
||||
tokens: 8,
|
||||
},
|
||||
]
|
||||
|
||||
|
Binary file not shown.
@ -347,7 +347,7 @@ test('should be able to update chat session prompt', async t => {
|
||||
t.truthy(sessionId, 'should create session');
|
||||
|
||||
// Update the session
|
||||
const updatedSessionId = await session.updateSession({
|
||||
const updatedSessionId = await session.update({
|
||||
sessionId,
|
||||
promptName: 'Search With AFFiNE AI',
|
||||
userId,
|
||||
|
@ -1,6 +1,6 @@
|
||||
import { randomUUID } from 'node:crypto';
|
||||
|
||||
import { AiSession, PrismaClient, User, Workspace } from '@prisma/client';
|
||||
import { PrismaClient, User, Workspace } from '@prisma/client';
|
||||
import ava, { TestFn } from 'ava';
|
||||
import Sinon from 'sinon';
|
||||
|
||||
@ -43,7 +43,7 @@ test.before(async t => {
|
||||
|
||||
let user: User;
|
||||
let workspace: Workspace;
|
||||
let session: AiSession;
|
||||
let sessionId: string;
|
||||
let docId = 'doc1';
|
||||
|
||||
test.beforeEach(async t => {
|
||||
@ -53,7 +53,7 @@ test.beforeEach(async t => {
|
||||
email: 'test@affine.pro',
|
||||
});
|
||||
workspace = await t.context.workspace.create(user.id);
|
||||
session = await t.context.copilotSession.create({
|
||||
sessionId = await t.context.copilotSession.create({
|
||||
sessionId: randomUUID(),
|
||||
workspaceId: workspace.id,
|
||||
docId,
|
||||
@ -68,7 +68,7 @@ test.after(async t => {
|
||||
});
|
||||
|
||||
test('should create a copilot context', async t => {
|
||||
const { id: contextId } = await t.context.copilotContext.create(session.id);
|
||||
const { id: contextId } = await t.context.copilotContext.create(sessionId);
|
||||
t.truthy(contextId);
|
||||
|
||||
const context = await t.context.copilotContext.get(contextId);
|
||||
@ -77,7 +77,7 @@ test('should create a copilot context', async t => {
|
||||
const config = await t.context.copilotContext.getConfig(contextId);
|
||||
t.is(config?.workspaceId, workspace.id, 'should get context config');
|
||||
|
||||
const context1 = await t.context.copilotContext.getBySessionId(session.id);
|
||||
const context1 = await t.context.copilotContext.getBySessionId(sessionId);
|
||||
t.is(context1?.id, contextId, 'should get context by session id');
|
||||
});
|
||||
|
||||
@ -87,7 +87,7 @@ test('should get null for non-exist job', async t => {
|
||||
});
|
||||
|
||||
test('should update context', async t => {
|
||||
const { id: contextId } = await t.context.copilotContext.create(session.id);
|
||||
const { id: contextId } = await t.context.copilotContext.create(sessionId);
|
||||
const config = await t.context.copilotContext.getConfig(contextId);
|
||||
|
||||
const doc = {
|
||||
@ -102,7 +102,7 @@ test('should update context', async t => {
|
||||
});
|
||||
|
||||
test('should insert embedding by doc id', async t => {
|
||||
const { id: contextId } = await t.context.copilotContext.create(session.id);
|
||||
const { id: contextId } = await t.context.copilotContext.create(sessionId);
|
||||
|
||||
{
|
||||
await t.context.copilotContext.insertFileEmbedding(contextId, 'file-id', [
|
||||
|
@ -6,7 +6,7 @@ import ava, { ExecutionContext, TestFn } from 'ava';
|
||||
import { CopilotPromptInvalid, CopilotSessionInvalidInput } from '../../base';
|
||||
import {
|
||||
CopilotSessionModel,
|
||||
UpdateChatSessionData,
|
||||
UpdateChatSessionOptions,
|
||||
UserModel,
|
||||
WorkspaceModel,
|
||||
} from '../../models';
|
||||
@ -174,7 +174,10 @@ test('should check session validation for prompts', async t => {
|
||||
sessionTypes.forEach(({ name, session }) => {
|
||||
t.notThrows(
|
||||
() =>
|
||||
copilotSession.checkSessionPrompt(session, 'test-prompt', undefined),
|
||||
copilotSession.checkSessionPrompt(session, {
|
||||
name: 'test-prompt',
|
||||
action: undefined,
|
||||
}),
|
||||
`${name} session should allow non-action prompts`
|
||||
);
|
||||
});
|
||||
@ -195,14 +198,20 @@ test('should check session validation for prompts', async t => {
|
||||
if (shouldThrow) {
|
||||
t.throws(
|
||||
() =>
|
||||
copilotSession.checkSessionPrompt(session, 'action-prompt', 'edit'),
|
||||
copilotSession.checkSessionPrompt(session, {
|
||||
name: 'action-prompt',
|
||||
action: 'edit',
|
||||
}),
|
||||
{ instanceOf: CopilotPromptInvalid },
|
||||
`${name} session should reject action prompts`
|
||||
);
|
||||
} else {
|
||||
t.notThrows(
|
||||
() =>
|
||||
copilotSession.checkSessionPrompt(session, 'action-prompt', 'edit'),
|
||||
copilotSession.checkSessionPrompt(session, {
|
||||
name: 'action-prompt',
|
||||
action: 'edit',
|
||||
}),
|
||||
`${name} session should allow action prompts`
|
||||
);
|
||||
}
|
||||
@ -323,14 +332,19 @@ test('should handle session updates and validations', async t => {
|
||||
},
|
||||
});
|
||||
|
||||
type UpdateData = Omit<UpdateChatSessionOptions, 'userId' | 'sessionId'>;
|
||||
const assertUpdateThrows = async (
|
||||
t: ExecutionContext<Context>,
|
||||
sessionId: string,
|
||||
updateData: UpdateChatSessionData,
|
||||
updateData: UpdateData,
|
||||
message: string
|
||||
) => {
|
||||
await t.throwsAsync(
|
||||
t.context.copilotSession.update(user.id, sessionId, updateData),
|
||||
t.context.copilotSession.update({
|
||||
...updateData,
|
||||
userId: user.id,
|
||||
sessionId,
|
||||
}),
|
||||
{ instanceOf: CopilotSessionInvalidInput },
|
||||
message
|
||||
);
|
||||
@ -339,11 +353,15 @@ test('should handle session updates and validations', async t => {
|
||||
const assertUpdate = async (
|
||||
t: ExecutionContext<Context>,
|
||||
sessionId: string,
|
||||
updateData: UpdateChatSessionData,
|
||||
updateData: UpdateData,
|
||||
message: string
|
||||
) => {
|
||||
await t.notThrowsAsync(
|
||||
t.context.copilotSession.update(user.id, sessionId, updateData),
|
||||
t.context.copilotSession.update({
|
||||
...updateData,
|
||||
userId: user.id,
|
||||
sessionId,
|
||||
}),
|
||||
message
|
||||
);
|
||||
};
|
||||
@ -386,7 +404,6 @@ test('should handle session updates and validations', async t => {
|
||||
'forked session should reject docId update'
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
// case 3: prompt update validation
|
||||
await assertUpdate(
|
||||
@ -415,14 +432,13 @@ test('should handle session updates and validations', async t => {
|
||||
await createTestSession(t, { sessionId: existingPinnedId, pinned: true });
|
||||
|
||||
// should unpin existing when pinning new session
|
||||
await copilotSession.update(user.id, sessionId, { pinned: true });
|
||||
await copilotSession.update({ userId: user.id, sessionId, pinned: true });
|
||||
|
||||
const sessionStatesAfterPin = await Promise.all([
|
||||
getSessionState(db, sessionId),
|
||||
getSessionState(db, existingPinnedId),
|
||||
]);
|
||||
t.snapshot(
|
||||
sessionStatesAfterPin,
|
||||
[
|
||||
await getSessionState(db, sessionId),
|
||||
await getSessionState(db, existingPinnedId),
|
||||
],
|
||||
'should unpin existing when pinning new session'
|
||||
);
|
||||
}
|
||||
@ -430,11 +446,8 @@ test('should handle session updates and validations', async t => {
|
||||
// test type conversions
|
||||
{
|
||||
const conversionSteps: any[] = [];
|
||||
const convertSession = async (
|
||||
step: string,
|
||||
data: UpdateChatSessionData
|
||||
) => {
|
||||
await copilotSession.update(user.id, sessionId, data);
|
||||
const convertSession = async (step: string, data: UpdateData) => {
|
||||
await copilotSession.update({ ...data, userId: user.id, sessionId });
|
||||
const session = await db.aiSession.findUnique({
|
||||
where: { id: sessionId },
|
||||
select: { docId: true, pinned: true },
|
||||
|
@ -9,6 +9,7 @@ import {
|
||||
CopilotSessionInvalidInput,
|
||||
CopilotSessionNotFound,
|
||||
} from '../base';
|
||||
import { getTokenEncoder } from '../native';
|
||||
import { BaseModel } from './base';
|
||||
|
||||
export enum SessionType {
|
||||
@ -17,6 +18,12 @@ export enum SessionType {
|
||||
Doc = 'doc', // docId points to specific document
|
||||
}
|
||||
|
||||
type ChatPrompt = {
|
||||
name: string;
|
||||
action?: string | null;
|
||||
model: string;
|
||||
};
|
||||
|
||||
type ChatAttachment = { attachment: string; mimeType: string } | string;
|
||||
|
||||
type ChatStreamObject = {
|
||||
@ -38,7 +45,7 @@ type ChatMessage = {
|
||||
createdAt: Date;
|
||||
};
|
||||
|
||||
type ChatSession = {
|
||||
type PureChatSession = {
|
||||
sessionId: string;
|
||||
workspaceId: string;
|
||||
docId?: string | null;
|
||||
@ -46,22 +53,44 @@ type ChatSession = {
|
||||
messages?: ChatMessage[];
|
||||
// connect ids
|
||||
userId: string;
|
||||
promptName: string;
|
||||
promptAction: string | null;
|
||||
parentSessionId?: string | null;
|
||||
};
|
||||
|
||||
export type UpdateChatSessionData = Partial<
|
||||
Pick<ChatSession, 'docId' | 'pinned' | 'promptName'>
|
||||
>;
|
||||
export type UpdateChatSession = Pick<ChatSession, 'userId' | 'sessionId'> &
|
||||
UpdateChatSessionData;
|
||||
type ChatSession = PureChatSession & {
|
||||
// connect ids
|
||||
promptName: string;
|
||||
promptAction: string | null;
|
||||
};
|
||||
|
||||
export type ListSessionOptions = {
|
||||
type ChatSessionWithPrompt = PureChatSession & {
|
||||
prompt: ChatPrompt;
|
||||
};
|
||||
|
||||
type ChatSessionBaseState = Pick<ChatSession, 'userId' | 'sessionId'>;
|
||||
|
||||
export type ForkSessionOptions = Omit<
|
||||
ChatSession,
|
||||
'messages' | 'promptName' | 'promptAction'
|
||||
> & {
|
||||
prompt: { name: string; action: string | null | undefined; model: string };
|
||||
messages: ChatMessage[];
|
||||
};
|
||||
|
||||
type UpdateChatSessionMessage = ChatSessionBaseState & {
|
||||
prompt: { model: string };
|
||||
messages: ChatMessage[];
|
||||
};
|
||||
|
||||
export type UpdateChatSessionOptions = ChatSessionBaseState &
|
||||
Pick<Partial<ChatSession>, 'docId' | 'pinned' | 'promptName'>;
|
||||
|
||||
export type UpdateChatSession = ChatSessionBaseState & UpdateChatSessionOptions;
|
||||
|
||||
export type ListSessionOptions = Pick<
|
||||
Partial<ChatSession>,
|
||||
'sessionId' | 'workspaceId' | 'docId' | 'pinned'
|
||||
> & {
|
||||
userId: string;
|
||||
sessionId?: string;
|
||||
workspaceId?: string;
|
||||
docId?: string;
|
||||
action?: boolean;
|
||||
fork?: boolean;
|
||||
limit?: number;
|
||||
@ -74,6 +103,13 @@ export type ListSessionOptions = {
|
||||
withMessages?: boolean;
|
||||
};
|
||||
|
||||
export type CleanupSessionOptions = Pick<
|
||||
ChatSession,
|
||||
'userId' | 'workspaceId' | 'docId'
|
||||
> & {
|
||||
sessionIds: string[];
|
||||
};
|
||||
|
||||
@Injectable()
|
||||
export class CopilotSessionModel extends BaseModel {
|
||||
getSessionType(session: Pick<ChatSession, 'docId' | 'pinned'>): SessionType {
|
||||
@ -84,10 +120,10 @@ export class CopilotSessionModel extends BaseModel {
|
||||
|
||||
checkSessionPrompt(
|
||||
session: Pick<ChatSession, 'docId' | 'pinned'>,
|
||||
promptName: string,
|
||||
promptAction: string | undefined
|
||||
prompt: Partial<ChatPrompt>
|
||||
): boolean {
|
||||
const sessionType = this.getSessionType(session);
|
||||
const { name: promptName, action: promptAction } = prompt;
|
||||
|
||||
// workspace and pinned sessions cannot use action prompts
|
||||
if (
|
||||
@ -110,12 +146,18 @@ export class CopilotSessionModel extends BaseModel {
|
||||
}
|
||||
|
||||
@Transactional()
|
||||
async create(state: ChatSession) {
|
||||
async create(state: ChatSession, reuseChat = false): Promise<string> {
|
||||
// find and return existing session if session is chat session
|
||||
if (reuseChat && !state.promptAction) {
|
||||
const sessionId = await this.find(state);
|
||||
if (sessionId) return sessionId;
|
||||
}
|
||||
|
||||
if (state.pinned) {
|
||||
await this.unpin(state.workspaceId, state.userId);
|
||||
}
|
||||
|
||||
const row = await this.db.aiSession.create({
|
||||
const session = await this.db.aiSession.create({
|
||||
data: {
|
||||
id: state.sessionId,
|
||||
workspaceId: state.workspaceId,
|
||||
@ -127,8 +169,46 @@ export class CopilotSessionModel extends BaseModel {
|
||||
promptAction: state.promptAction,
|
||||
parentSessionId: state.parentSessionId,
|
||||
},
|
||||
select: { id: true },
|
||||
});
|
||||
return row;
|
||||
return session.id;
|
||||
}
|
||||
|
||||
@Transactional()
|
||||
async createWithPrompt(
|
||||
state: ChatSessionWithPrompt,
|
||||
reuseChat = false
|
||||
): Promise<string> {
|
||||
const { prompt, ...rest } = state;
|
||||
return await this.models.copilotSession.create(
|
||||
{ ...rest, promptName: prompt.name, promptAction: prompt.action ?? null },
|
||||
reuseChat
|
||||
);
|
||||
}
|
||||
|
||||
@Transactional()
|
||||
async fork(options: ForkSessionOptions): Promise<string> {
|
||||
if (!options.messages?.length) {
|
||||
throw new CopilotSessionInvalidInput(
|
||||
'Cannot fork session without messages'
|
||||
);
|
||||
}
|
||||
if (options.pinned) {
|
||||
await this.unpin(options.workspaceId, options.userId);
|
||||
}
|
||||
const { messages, ...forkedState } = options;
|
||||
|
||||
// create session
|
||||
const sessionId = await this.createWithPrompt({
|
||||
...forkedState,
|
||||
messages: [],
|
||||
});
|
||||
// save message
|
||||
await this.models.copilotSession.updateMessages({
|
||||
...forkedState,
|
||||
messages,
|
||||
});
|
||||
return sessionId;
|
||||
}
|
||||
|
||||
@Transactional()
|
||||
@ -143,9 +223,7 @@ export class CopilotSessionModel extends BaseModel {
|
||||
}
|
||||
|
||||
@Transactional()
|
||||
async getChatSessionId(
|
||||
state: Omit<ChatSession, 'promptName' | 'promptAction'>
|
||||
) {
|
||||
async find(state: PureChatSession) {
|
||||
const extraCondition: Record<string, any> = {};
|
||||
if (state.parentSessionId) {
|
||||
// also check session id if provided session is forked session
|
||||
@ -287,11 +365,8 @@ export class CopilotSessionModel extends BaseModel {
|
||||
}
|
||||
|
||||
@Transactional()
|
||||
async update(
|
||||
userId: string,
|
||||
sessionId: string,
|
||||
data: UpdateChatSessionData
|
||||
): Promise<string> {
|
||||
async update(options: UpdateChatSessionOptions): Promise<string> {
|
||||
const { userId, sessionId, docId, promptName, pinned } = options;
|
||||
const session = await this.getExists(
|
||||
sessionId,
|
||||
{
|
||||
@ -313,33 +388,71 @@ export class CopilotSessionModel extends BaseModel {
|
||||
throw new CopilotSessionInvalidInput(
|
||||
`Cannot update action: ${session.id}`
|
||||
);
|
||||
} else if (data.docId && session.parentSessionId) {
|
||||
} else if (docId && session.parentSessionId) {
|
||||
throw new CopilotSessionInvalidInput(
|
||||
`Cannot update docId for forked session: ${session.id}`
|
||||
);
|
||||
}
|
||||
|
||||
if (data.promptName) {
|
||||
if (promptName) {
|
||||
const prompt = await this.db.aiPrompt.findFirst({
|
||||
where: { name: data.promptName },
|
||||
where: { name: promptName },
|
||||
});
|
||||
// always not allow to update to action prompt
|
||||
if (!prompt || prompt.action) {
|
||||
throw new CopilotSessionInvalidInput(
|
||||
`Prompt ${data.promptName} not found or not available for session ${sessionId}`
|
||||
`Prompt ${promptName} not found or not available for session ${sessionId}`
|
||||
);
|
||||
}
|
||||
}
|
||||
if (data.pinned && data.pinned !== session.pinned) {
|
||||
if (pinned && pinned !== session.pinned) {
|
||||
// if pin the session, unpin exists session in the workspace
|
||||
await this.unpin(session.workspaceId, userId);
|
||||
}
|
||||
|
||||
await this.db.aiSession.update({ where: { id: sessionId }, data });
|
||||
await this.db.aiSession.update({
|
||||
where: { id: sessionId },
|
||||
data: { docId, promptName, pinned },
|
||||
});
|
||||
|
||||
return sessionId;
|
||||
}
|
||||
|
||||
@Transactional()
|
||||
async cleanup(options: CleanupSessionOptions): Promise<string[]> {
|
||||
const sessions = await this.db.aiSession.findMany({
|
||||
where: {
|
||||
id: { in: options.sessionIds },
|
||||
userId: options.userId,
|
||||
workspaceId: options.workspaceId,
|
||||
docId: options.docId,
|
||||
deletedAt: null,
|
||||
},
|
||||
select: { id: true, prompt: true },
|
||||
});
|
||||
const sessionIds = sessions.map(({ id }) => id);
|
||||
// cleanup all messages
|
||||
await this.db.aiSessionMessage.deleteMany({
|
||||
where: { sessionId: { in: sessionIds } },
|
||||
});
|
||||
|
||||
// only mark action session as deleted
|
||||
// chat session always can be reuse
|
||||
const actionIds = sessions
|
||||
.filter(({ prompt }) => !!prompt.action)
|
||||
.map(({ id }) => id);
|
||||
|
||||
// 标记 action session 为已删除
|
||||
if (actionIds.length > 0) {
|
||||
await this.db.aiSession.updateMany({
|
||||
where: { id: { in: actionIds } },
|
||||
data: { pinned: false, deletedAt: new Date() },
|
||||
});
|
||||
}
|
||||
|
||||
return sessionIds;
|
||||
}
|
||||
|
||||
@Transactional()
|
||||
async getMessages(
|
||||
sessionId: string,
|
||||
@ -353,31 +466,42 @@ export class CopilotSessionModel extends BaseModel {
|
||||
});
|
||||
}
|
||||
|
||||
@Transactional()
|
||||
async setMessages(
|
||||
sessionId: string,
|
||||
messages: ChatMessage[],
|
||||
tokenCost: number
|
||||
) {
|
||||
await this.db.aiSessionMessage.createMany({
|
||||
data: messages.map(m => ({
|
||||
...m,
|
||||
attachments: m.attachments || undefined,
|
||||
params: omit(m.params, ['docs']) || undefined,
|
||||
streamObjects: m.streamObjects || undefined,
|
||||
sessionId,
|
||||
})),
|
||||
});
|
||||
private calculateTokenSize(messages: any[], model: string): number {
|
||||
const encoder = getTokenEncoder(model);
|
||||
const content = messages.map(m => m.content).join('');
|
||||
return encoder?.count(content) || 0;
|
||||
}
|
||||
|
||||
// only count message generated by user
|
||||
const userMessages = messages.filter(m => m.role === 'user');
|
||||
await this.db.aiSession.update({
|
||||
where: { id: sessionId },
|
||||
data: {
|
||||
messageCost: { increment: userMessages.length },
|
||||
tokenCost: { increment: tokenCost },
|
||||
},
|
||||
});
|
||||
@Transactional()
|
||||
async updateMessages(state: UpdateChatSessionMessage) {
|
||||
const { sessionId, userId, messages } = state;
|
||||
const haveSession = await this.has(sessionId, userId);
|
||||
if (!haveSession) {
|
||||
throw new CopilotSessionNotFound();
|
||||
}
|
||||
|
||||
if (messages.length) {
|
||||
const tokenCost = this.calculateTokenSize(messages, state.prompt.model);
|
||||
await this.db.aiSessionMessage.createMany({
|
||||
data: messages.map(m => ({
|
||||
...m,
|
||||
attachments: m.attachments || undefined,
|
||||
params: omit(m.params, ['docs']) || undefined,
|
||||
streamObjects: m.streamObjects || undefined,
|
||||
sessionId,
|
||||
})),
|
||||
});
|
||||
|
||||
// only count message generated by user
|
||||
const userMessages = messages.filter(m => m.role === 'user');
|
||||
await this.db.aiSession.update({
|
||||
where: { id: sessionId },
|
||||
data: {
|
||||
messageCost: { increment: userMessages.length },
|
||||
tokenCost: { increment: tokenCost },
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@Transactional()
|
||||
@ -404,4 +528,15 @@ export class CopilotSessionModel extends BaseModel {
|
||||
await this.db.aiSessionMessage.deleteMany({ where: { id: { in: ids } } });
|
||||
}
|
||||
}
|
||||
|
||||
@Transactional()
|
||||
async countUserMessages(userId: string): Promise<number> {
|
||||
const sessions = await this.db.aiSession.findMany({
|
||||
where: { userId },
|
||||
select: { messageCost: true, prompt: { select: { action: true } } },
|
||||
});
|
||||
return sessions
|
||||
.map(({ messageCost, prompt: { action } }) => (action ? 1 : messageCost))
|
||||
.reduce((prev, cost) => prev + cost, 0);
|
||||
}
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
import serverNativeModule from '@affine/server-native';
|
||||
import serverNativeModule, { type Tokenizer } from '@affine/server-native';
|
||||
|
||||
export const mergeUpdatesInApplyWay = serverNativeModule.mergeUpdatesInApplyWay;
|
||||
|
||||
@ -16,10 +16,21 @@ export const mintChallengeResponse = async (resource: string, bits: number) => {
|
||||
return serverNativeModule.mintChallengeResponse(resource, bits);
|
||||
};
|
||||
|
||||
export function getTokenEncoder(model?: string | null): Tokenizer | null {
|
||||
if (!model) return null;
|
||||
if (model.startsWith('gpt')) {
|
||||
return serverNativeModule.fromModelName(model);
|
||||
} else if (model.startsWith('dall')) {
|
||||
// dalle don't need to calc the token
|
||||
return null;
|
||||
} else {
|
||||
// c100k based model
|
||||
return serverNativeModule.fromModelName('gpt-4');
|
||||
}
|
||||
}
|
||||
|
||||
export const getMime = serverNativeModule.getMime;
|
||||
export const parseDoc = serverNativeModule.parseDoc;
|
||||
export const Tokenizer = serverNativeModule.Tokenizer;
|
||||
export const fromModelName = serverNativeModule.fromModelName;
|
||||
export const htmlSanitize = serverNativeModule.htmlSanitize;
|
||||
export const AFFINE_PRO_PUBLIC_KEY = serverNativeModule.AFFINE_PRO_PUBLIC_KEY;
|
||||
export const AFFINE_PRO_LICENSE_AES_KEY =
|
||||
|
@ -3,8 +3,8 @@ import { Logger } from '@nestjs/common';
|
||||
import { AiPrompt } from '@prisma/client';
|
||||
import Mustache from 'mustache';
|
||||
|
||||
import { getTokenEncoder } from '../../../native';
|
||||
import { PromptConfig, PromptMessage, PromptParams } from '../providers';
|
||||
import { getTokenEncoder } from '../types';
|
||||
|
||||
// disable escaping
|
||||
Mustache.escape = (text: string) => text;
|
||||
@ -56,8 +56,7 @@ export class ChatPrompt {
|
||||
private readonly messages: PromptMessage[]
|
||||
) {
|
||||
this.encoder = getTokenEncoder(model);
|
||||
this.promptTokenSize =
|
||||
this.encoder?.count(messages.map(m => m.content).join('') || '') || 0;
|
||||
this.promptTokenSize = this.encode(messages.map(m => m.content).join(''));
|
||||
this.templateParamKeys = extractMustacheParams(
|
||||
messages.map(m => m.content).join('')
|
||||
);
|
||||
|
@ -39,15 +39,12 @@ import { PromptMessage, StreamObject } from './providers';
|
||||
import { ChatSessionService } from './session';
|
||||
import { CopilotStorage } from './storage';
|
||||
import {
|
||||
AvailableModels,
|
||||
type ChatHistory,
|
||||
type ChatMessage,
|
||||
type ChatSessionState,
|
||||
SubmittedMessage,
|
||||
} from './types';
|
||||
|
||||
registerEnumType(AvailableModels, { name: 'CopilotModel' });
|
||||
|
||||
export const COPILOT_LOCKER = 'copilot';
|
||||
|
||||
// ================== Input Types ==================
|
||||
@ -301,8 +298,6 @@ class CopilotPromptMessageType {
|
||||
params!: Record<string, string> | null;
|
||||
}
|
||||
|
||||
registerEnumType(AvailableModels, { name: 'CopilotModels' });
|
||||
|
||||
@ObjectType()
|
||||
class CopilotPromptType {
|
||||
@Field(() => String)
|
||||
@ -533,7 +528,7 @@ export class CopilotResolver {
|
||||
}
|
||||
|
||||
await this.chatSession.checkQuota(user.id);
|
||||
return await this.chatSession.updateSession({
|
||||
return await this.chatSession.update({
|
||||
...options,
|
||||
userId: user.id,
|
||||
});
|
||||
@ -682,8 +677,8 @@ class CreateCopilotPromptInput {
|
||||
@Field(() => String)
|
||||
name!: string;
|
||||
|
||||
@Field(() => AvailableModels)
|
||||
model!: AvailableModels;
|
||||
@Field(() => String)
|
||||
model!: string;
|
||||
|
||||
@Field(() => String, { nullable: true })
|
||||
action!: string | null;
|
||||
|
@ -2,7 +2,7 @@ import { randomUUID } from 'node:crypto';
|
||||
|
||||
import { Injectable, Logger } from '@nestjs/common';
|
||||
import { Transactional } from '@nestjs-cls/transactional';
|
||||
import { AiPromptRole, PrismaClient } from '@prisma/client';
|
||||
import { AiPromptRole } from '@prisma/client';
|
||||
|
||||
import {
|
||||
CopilotActionTaken,
|
||||
@ -14,10 +14,11 @@ import {
|
||||
} from '../../base';
|
||||
import { QuotaService } from '../../core/quota';
|
||||
import {
|
||||
CleanupSessionOptions,
|
||||
ListSessionOptions,
|
||||
Models,
|
||||
type UpdateChatSession,
|
||||
UpdateChatSessionData,
|
||||
UpdateChatSessionOptions,
|
||||
} from '../../models';
|
||||
import { ChatMessageCache } from './message';
|
||||
import { PromptService } from './prompt';
|
||||
@ -29,7 +30,6 @@ import {
|
||||
type ChatSessionForkOptions,
|
||||
type ChatSessionOptions,
|
||||
type ChatSessionState,
|
||||
getTokenEncoder,
|
||||
type SubmittedMessage,
|
||||
} from './types';
|
||||
|
||||
@ -224,46 +224,12 @@ export class ChatSessionService {
|
||||
private readonly logger = new Logger(ChatSessionService.name);
|
||||
|
||||
constructor(
|
||||
private readonly db: PrismaClient,
|
||||
private readonly quota: QuotaService,
|
||||
private readonly messageCache: ChatMessageCache,
|
||||
private readonly prompt: PromptService,
|
||||
private readonly models: Models
|
||||
) {}
|
||||
|
||||
@Transactional()
|
||||
private async setSession(state: ChatSessionState): Promise<string> {
|
||||
const session = this.models.copilotSession;
|
||||
let sessionId = state.sessionId;
|
||||
|
||||
// find existing session if session is chat session
|
||||
if (!state.prompt.action) {
|
||||
const id = await session.getChatSessionId(state);
|
||||
if (id) sessionId = id;
|
||||
}
|
||||
|
||||
const haveSession = await session.has(sessionId, state.userId);
|
||||
if (haveSession) {
|
||||
// message will only exists when setSession call by session.save
|
||||
if (state.messages.length) {
|
||||
await session.setMessages(
|
||||
sessionId,
|
||||
state.messages,
|
||||
this.calculateTokenSize(state.messages, state.prompt.model)
|
||||
);
|
||||
}
|
||||
} else {
|
||||
await session.create({
|
||||
...state,
|
||||
sessionId,
|
||||
promptName: state.prompt.name,
|
||||
promptAction: state.prompt.action ?? null,
|
||||
});
|
||||
}
|
||||
|
||||
return sessionId;
|
||||
}
|
||||
|
||||
async getSession(sessionId: string): Promise<ChatSessionState | undefined> {
|
||||
const session = await this.models.copilotSession.get(sessionId);
|
||||
if (!session) return;
|
||||
@ -296,23 +262,6 @@ export class ChatSessionService {
|
||||
);
|
||||
}
|
||||
|
||||
private calculateTokenSize(messages: PromptMessage[], model: string): number {
|
||||
const encoder = getTokenEncoder(model);
|
||||
return messages
|
||||
.map(m => encoder?.count(m.content) ?? 0)
|
||||
.reduce((total, length) => total + length, 0);
|
||||
}
|
||||
|
||||
private async countUserMessages(userId: string): Promise<number> {
|
||||
const sessions = await this.db.aiSession.findMany({
|
||||
where: { userId },
|
||||
select: { messageCost: true, prompt: { select: { action: true } } },
|
||||
});
|
||||
return sessions
|
||||
.map(({ messageCost, prompt: { action } }) => (action ? 1 : messageCost))
|
||||
.reduce((prev, cost) => prev + cost, 0);
|
||||
}
|
||||
|
||||
async listSessions(
|
||||
options: ListSessionOptions
|
||||
): Promise<Omit<ChatSessionState, 'messages'>[]> {
|
||||
@ -431,7 +380,7 @@ export class ChatSessionService {
|
||||
limit = quota.copilotActionLimit;
|
||||
}
|
||||
|
||||
const used = await this.countUserMessages(userId);
|
||||
const used = await this.models.copilotSession.countUserMessages(userId);
|
||||
|
||||
return { limit, used };
|
||||
}
|
||||
@ -456,20 +405,19 @@ export class ChatSessionService {
|
||||
}
|
||||
|
||||
// validate prompt compatibility with session type
|
||||
this.models.copilotSession.checkSessionPrompt(
|
||||
options,
|
||||
prompt.name,
|
||||
prompt.action
|
||||
);
|
||||
this.models.copilotSession.checkSessionPrompt(options, prompt);
|
||||
|
||||
return await this.setSession({
|
||||
...options,
|
||||
sessionId,
|
||||
prompt,
|
||||
messages: [],
|
||||
// when client create chat session, we always find root session
|
||||
parentSessionId: null,
|
||||
});
|
||||
return await this.models.copilotSession.createWithPrompt(
|
||||
{
|
||||
...options,
|
||||
sessionId,
|
||||
prompt,
|
||||
messages: [],
|
||||
// when client create chat session, we always find root session
|
||||
parentSessionId: null,
|
||||
},
|
||||
true
|
||||
);
|
||||
}
|
||||
|
||||
@Transactional()
|
||||
@ -478,13 +426,16 @@ export class ChatSessionService {
|
||||
}
|
||||
|
||||
@Transactional()
|
||||
async updateSession(options: UpdateChatSession): Promise<string> {
|
||||
async update(options: UpdateChatSession): Promise<string> {
|
||||
const session = await this.getSession(options.sessionId);
|
||||
if (!session) {
|
||||
throw new CopilotSessionNotFound();
|
||||
}
|
||||
|
||||
const finalData: UpdateChatSessionData = {};
|
||||
const finalData: UpdateChatSessionOptions = {
|
||||
userId: options.userId,
|
||||
sessionId: options.sessionId,
|
||||
};
|
||||
if (options.promptName) {
|
||||
const prompt = await this.prompt.get(options.promptName);
|
||||
if (!prompt) {
|
||||
@ -492,11 +443,7 @@ export class ChatSessionService {
|
||||
throw new CopilotPromptNotFound({ name: options.promptName });
|
||||
}
|
||||
|
||||
this.models.copilotSession.checkSessionPrompt(
|
||||
session,
|
||||
prompt.name,
|
||||
prompt.action
|
||||
);
|
||||
this.models.copilotSession.checkSessionPrompt(session, prompt);
|
||||
finalData.promptName = prompt.name;
|
||||
}
|
||||
finalData.pinned = options.pinned;
|
||||
@ -508,21 +455,15 @@ export class ChatSessionService {
|
||||
);
|
||||
}
|
||||
|
||||
return await this.models.copilotSession.update(
|
||||
options.userId,
|
||||
options.sessionId,
|
||||
finalData
|
||||
);
|
||||
return await this.models.copilotSession.update(finalData);
|
||||
}
|
||||
|
||||
@Transactional()
|
||||
async fork(options: ChatSessionForkOptions): Promise<string> {
|
||||
const state = await this.getSession(options.sessionId);
|
||||
if (!state) {
|
||||
throw new CopilotSessionNotFound();
|
||||
}
|
||||
if (state.pinned) {
|
||||
await this.unpin(options.workspaceId, options.userId);
|
||||
}
|
||||
|
||||
let messages = state.messages.map(m => ({ ...m, id: undefined }));
|
||||
if (options.latestMessageId) {
|
||||
@ -538,62 +479,17 @@ export class ChatSessionService {
|
||||
messages = messages.slice(0, lastMessageIdx + 1);
|
||||
}
|
||||
|
||||
const forkedState = {
|
||||
return await this.models.copilotSession.fork({
|
||||
...state,
|
||||
userId: options.userId,
|
||||
sessionId: randomUUID(),
|
||||
messages: [],
|
||||
parentSessionId: options.sessionId,
|
||||
};
|
||||
// create session
|
||||
await this.setSession(forkedState);
|
||||
// save message
|
||||
return await this.setSession({ ...forkedState, messages });
|
||||
messages,
|
||||
});
|
||||
}
|
||||
|
||||
async cleanup(
|
||||
options: Omit<ChatSessionOptions, 'pinned' | 'promptName'> & {
|
||||
sessionIds: string[];
|
||||
}
|
||||
) {
|
||||
return await this.db.$transaction(async tx => {
|
||||
const sessions = await tx.aiSession.findMany({
|
||||
where: {
|
||||
id: { in: options.sessionIds },
|
||||
userId: options.userId,
|
||||
workspaceId: options.workspaceId,
|
||||
docId: options.docId,
|
||||
deletedAt: null,
|
||||
},
|
||||
select: { id: true, promptName: true },
|
||||
});
|
||||
const sessionIds = sessions.map(({ id }) => id);
|
||||
// cleanup all messages
|
||||
await tx.aiSessionMessage.deleteMany({
|
||||
where: { sessionId: { in: sessionIds } },
|
||||
});
|
||||
|
||||
// only mark action session as deleted
|
||||
// chat session always can be reuse
|
||||
const actionIds = (
|
||||
await Promise.all(
|
||||
sessions.map(({ id, promptName }) =>
|
||||
this.prompt
|
||||
.get(promptName)
|
||||
.then(prompt => ({ id, action: !!prompt?.action }))
|
||||
)
|
||||
)
|
||||
)
|
||||
.filter(({ action }) => action)
|
||||
.map(({ id }) => id);
|
||||
|
||||
await tx.aiSession.updateMany({
|
||||
where: { id: { in: actionIds } },
|
||||
data: { pinned: false, deletedAt: new Date() },
|
||||
});
|
||||
|
||||
return [...sessionIds, ...actionIds];
|
||||
});
|
||||
async cleanup(options: CleanupSessionOptions) {
|
||||
return await this.models.copilotSession.cleanup(options);
|
||||
}
|
||||
|
||||
async createMessage(message: SubmittedMessage): Promise<string> {
|
||||
@ -617,7 +513,7 @@ export class ChatSessionService {
|
||||
const state = await this.getSession(sessionId);
|
||||
if (state) {
|
||||
return new ChatSession(this.messageCache, state, async state => {
|
||||
await this.setSession(state);
|
||||
await this.models.copilotSession.updateMessages(state);
|
||||
});
|
||||
}
|
||||
return null;
|
||||
|
@ -1,8 +1,6 @@
|
||||
import { type Tokenizer } from '@affine/server-native';
|
||||
import { z } from 'zod';
|
||||
|
||||
import { OneMB } from '../../base';
|
||||
import { fromModelName } from '../../native';
|
||||
import type { ChatPrompt } from './prompt';
|
||||
import { PromptMessageSchema, PureMessageSchema } from './providers';
|
||||
|
||||
@ -38,41 +36,6 @@ export const ChatQuerySchema = z
|
||||
})
|
||||
);
|
||||
|
||||
export enum AvailableModels {
|
||||
// text to text
|
||||
Gpt4Omni = 'gpt-4o',
|
||||
Gpt4Omni0806 = 'gpt-4o-2024-08-06',
|
||||
Gpt4OmniMini = 'gpt-4o-mini',
|
||||
Gpt4OmniMini0718 = 'gpt-4o-mini-2024-07-18',
|
||||
Gpt41 = 'gpt-4.1',
|
||||
Gpt410414 = 'gpt-4.1-2025-04-14',
|
||||
Gpt41Mini = 'gpt-4.1-mini',
|
||||
Gpt41Nano = 'gpt-4.1-nano',
|
||||
// embeddings
|
||||
TextEmbedding3Large = 'text-embedding-3-large',
|
||||
TextEmbedding3Small = 'text-embedding-3-small',
|
||||
TextEmbeddingAda002 = 'text-embedding-ada-002',
|
||||
// text to image
|
||||
DallE3 = 'dall-e-3',
|
||||
GptImage = 'gpt-image-1',
|
||||
}
|
||||
|
||||
const availableModels = Object.values(AvailableModels);
|
||||
|
||||
export function getTokenEncoder(model?: string | null): Tokenizer | null {
|
||||
if (!model) return null;
|
||||
if (!availableModels.includes(model as AvailableModels)) return null;
|
||||
if (model.startsWith('gpt')) {
|
||||
return fromModelName(model);
|
||||
} else if (model.startsWith('dall')) {
|
||||
// dalle don't need to calc the token
|
||||
return null;
|
||||
} else {
|
||||
// c100k based model
|
||||
return fromModelName('gpt-4');
|
||||
}
|
||||
}
|
||||
|
||||
// ======== ChatMessage ========
|
||||
|
||||
export const ChatMessageSchema = PromptMessageSchema.extend({
|
||||
|
Loading…
x
Reference in New Issue
Block a user