Commit ebd75809 authored by jaden's avatar jaden

feat: functions feature

parent 3e321e59
...@@ -3,6 +3,7 @@ import { backendApi } from '.'; ...@@ -3,6 +3,7 @@ import { backendApi } from '.';
import { getModel } from '@/utils/gpt'; import { getModel } from '@/utils/gpt';
import { fetchEventSource, EventStreamContentType } from '@microsoft/fetch-event-source'; import { fetchEventSource, EventStreamContentType } from '@microsoft/fetch-event-source';
import { MutableRefObject, SetStateAction } from 'react'; import { MutableRefObject, SetStateAction } from 'react';
import { functions, get } from 'lodash';
export interface View { export interface View {
type: 'schema' | 'table'; type: 'schema' | 'table';
...@@ -36,6 +37,7 @@ export default class OpenAI { ...@@ -36,6 +37,7 @@ export default class OpenAI {
onFinish: (responseText: string) => any, onFinish: (responseText: string) => any,
onUpdate: (responseText: string, delta: string) => any, onUpdate: (responseText: string, delta: string) => any,
onError: (e: Error) => any, onError: (e: Error) => any,
functions: any[] = [],
closeFn?: MutableRefObject<any>, closeFn?: MutableRefObject<any>,
stream: boolean = true stream: boolean = true
) { ) {
...@@ -46,6 +48,7 @@ export default class OpenAI { ...@@ -46,6 +48,7 @@ export default class OpenAI {
frequency_penalty: model.frequency_penalty, frequency_penalty: model.frequency_penalty,
presence_penalty: model.presence_penalty, presence_penalty: model.presence_penalty,
stream, stream,
functions: functions,
}; };
const controller = new AbortController(); const controller = new AbortController();
const chatPayload = { const chatPayload = {
...@@ -116,8 +119,13 @@ export default class OpenAI { ...@@ -116,8 +119,13 @@ export default class OpenAI {
} }
const text = msg.data; const text = msg.data;
try { try {
console.log(text);
const json = JSON.parse(text); const json = JSON.parse(text);
const delta = json.choices[0].delta.content; const delta = json.choices[0].delta.content;
if (json.choices[0].finish_reason === 'function_call') {
// 🌟 output the Tree structure data
console.log(json);
}
if (delta) { if (delta) {
responseText += delta; responseText += delta;
onUpdate?.(responseText, delta); onUpdate?.(responseText, delta);
...@@ -139,7 +147,9 @@ export default class OpenAI { ...@@ -139,7 +147,9 @@ export default class OpenAI {
const res = await fetch(chatPath, chatPayload); const res = await fetch(chatPath, chatPayload);
clearTimeout(requestTimeoutId); clearTimeout(requestTimeoutId);
const resJson = await res.json(); const resJson = await res.json();
const message = resJson.choices?.at(0)?.message?.content ?? ''; const message =
resJson.choices?.at(0)?.message?.content ??
(JSON.stringify(get(resJson, 'choices[0].message.function_call')) || '');
console.log(message); console.log(message);
onFinish(message); onFinish(message);
} }
......
...@@ -50,6 +50,8 @@ type AIProps = { ...@@ -50,6 +50,8 @@ type AIProps = {
inputProps?: Record<string, any>; inputProps?: Record<string, any>;
SendButton?: ({ inputRef }: { inputRef: any }) => JSX.Element; SendButton?: ({ inputRef }: { inputRef: any }) => JSX.Element;
noHistory?: boolean; noHistory?: boolean;
functions: any[];
stream?: boolean;
}; };
export function AIWrapper({ export function AIWrapper({
...@@ -68,6 +70,8 @@ export function AIWrapper({ ...@@ -68,6 +70,8 @@ export function AIWrapper({
inputProps = {}, inputProps = {},
SendButton, SendButton,
noHistory, noHistory,
functions,
stream,
}: AIProps) { }: AIProps) {
const input = useRef<any>(); const input = useRef<any>();
const scrollContainer = useRef<any>(); const scrollContainer = useRef<any>();
...@@ -164,7 +168,9 @@ export function AIWrapper({ ...@@ -164,7 +168,9 @@ export function AIWrapper({
content: undefined, content: undefined,
}); });
}, },
closeRef functions,
closeRef,
stream || !functions
); );
}, },
[ [
......
export const QUERY_FUNCTION = [
{
name: 'saveExecuteSqlInfo',
description: '存储执行该需求所需要的关键信息',
parameters: {
type: 'object',
properties: {
sql: {
type: 'string',
description: 'The SQL statement to execute',
},
variablesArr: {
type: 'array',
description: 'The array containing variables and their descriptions',
items: {
type: 'object',
properties: {
variable: {
type: 'string',
description: 'The variable',
},
varDescription: {
type: 'string',
description: 'The description of the variable',
},
},
required: ['variable', 'varDescription'],
},
},
queryName: {
type: 'string',
description: 'The name of the query',
},
queryDescription: {
type: 'string',
description: 'The description of the query',
},
},
required: ['sql', 'variablesArr', 'queryName', 'queryDescription'],
},
},
];
...@@ -372,7 +372,7 @@ export function ChatView({ ...@@ -372,7 +372,7 @@ export function ChatView({
const [FunctionOption, setFunctionOption] = useState<ReturnType<typeof Parser.parse>[]>([]); const [FunctionOption, setFunctionOption] = useState<ReturnType<typeof Parser.parse>[]>([]);
return ( return (
<div> <div>
{props.length ? ( {propsRaw.length ? (
<> <>
<div className="flex justify-between items-center mb-[20px]"> <div className="flex justify-between items-center mb-[20px]">
<AI <AI
......
...@@ -44,6 +44,7 @@ import { ...@@ -44,6 +44,7 @@ import {
filter, filter,
find, find,
first, first,
forEach,
get, get,
isEqual, isEqual,
map, map,
...@@ -73,6 +74,8 @@ import { DataTable, QueriesList } from './queriesList'; ...@@ -73,6 +74,8 @@ import { DataTable, QueriesList } from './queriesList';
import { GET_QUERY, GET_SCHEMA_INFO } from '@/data/prompt'; import { GET_QUERY, GET_SCHEMA_INFO } from '@/data/prompt';
import Welcome from 'components/AITool/MessageItem'; import Welcome from 'components/AITool/MessageItem';
import * as queryTipRaw from '@/data/prompt/query-tip'; import * as queryTipRaw from '@/data/prompt/query-tip';
import { QUERY_FUNCTION } from '@/data/prompt/functions';
import { FunctionsJson, functionsJson } from '@/utils/getXMLContent';
const queryTip = map(queryTipRaw, v => { const queryTip = map(queryTipRaw, v => {
return { return {
name: v, name: v,
...@@ -281,6 +284,8 @@ const MessageItemHOC = ({ setShowQueriesList, activeDb, currentModels, activeMod ...@@ -281,6 +284,8 @@ const MessageItemHOC = ({ setShowQueriesList, activeDb, currentModels, activeMod
queryName: any; queryName: any;
queryDescription: any; queryDescription: any;
}>(() => { }>(() => {
const funcJson = new FunctionsJson(message);
const elNode = document.createElement('div'); const elNode = document.createElement('div');
elNode.innerHTML = message; elNode.innerHTML = message;
let sqlNodes = elNode.querySelectorAll('sql'); let sqlNodes = elNode.querySelectorAll('sql');
...@@ -290,27 +295,19 @@ const MessageItemHOC = ({ setShowQueriesList, activeDb, currentModels, activeMod ...@@ -290,27 +295,19 @@ const MessageItemHOC = ({ setShowQueriesList, activeDb, currentModels, activeMod
elNode.innerHTML = str; elNode.innerHTML = str;
sqlNodes = elNode.querySelectorAll('sql'); sqlNodes = elNode.querySelectorAll('sql');
} }
const varNodes = elNode.querySelectorAll('var'); const params = {};
const varDescriptionNodes = elNode.querySelectorAll('varDescription'); forEach(funcJson.get('variablesArr'), (item: any) => {
const queryName = elNode.querySelector('queryName'); const key = get(item, 'variable');
const queryDescription = elNode.querySelector('queryDescription'); if (key) {
set(params, key, get(item, 'varDescription'));
}
});
return { return {
sql: map(sqlNodes, node => node.textContent?.trim()), sql: [funcJson.get('sql').replace(/;\s+/g, ';\n')],
types: map(sqlNodes, node => getQueryType(node.tagName)), types: ['sql'],
params: pickBy( params,
Object.fromEntries( queryName: funcJson.get('queryName'),
map(varNodes, (node, index) => [ queryDescription: funcJson.get('queryDescription'),
node.textContent?.trim(),
varDescriptionNodes[index]?.textContent?.trim() || '',
])
),
(_, key) => {
return /^\$(.*)\$$/.test(key);
}
),
queryName: queryName?.textContent?.trim(),
queryDescription: queryDescription?.textContent?.trim(),
}; };
}, [message]); }, [message]);
const [preViewData, setPreViewData] = useState([]); const [preViewData, setPreViewData] = useState([]);
...@@ -805,6 +802,8 @@ INSERT INTO users (email, name) VALUES ($email$, $name$); ...@@ -805,6 +802,8 @@ INSERT INTO users (email, name) VALUES ($email$, $name$);
> >
<Content> <Content>
<AI <AI
functions={QUERY_FUNCTION}
stream={false}
noHistory noHistory
quickTip={queryTip} quickTip={queryTip}
welcome={ welcome={
......
import { functions, get } from 'lodash';
export class XML { export class XML {
private root: HTMLDivElement; private root: HTMLDivElement;
...@@ -12,3 +13,15 @@ export class XML { ...@@ -12,3 +13,15 @@ export class XML {
return insertText; return insertText;
} }
} }
export class FunctionsJson {
private root: HTMLDivElement;
constructor(code: string) {
this.root = JSON.parse(code);
}
get(code: string): string {
return JSON.parse(get(this.root, 'arguments', '{}'))[code];
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment