Skip to content

Commit

Permalink
feat(Postgres Chat Memory, Redis Chat Memory, Xata): Add support for …
Browse files Browse the repository at this point in the history
…context window length (#10203)
  • Loading branch information
jeanpaul authored Aug 6, 2024
1 parent 1eba7c3 commit e3edeaa
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import type { BufferWindowMemoryInput } from 'langchain/memory';
import { BufferWindowMemory } from 'langchain/memory';
import { logWrapper } from '../../../utils/logWrapper';
import { getConnectionHintNoticeField } from '../../../utils/sharedFields';
import { sessionIdOption, sessionKeyProperty } from '../descriptions';
import { sessionIdOption, sessionKeyProperty, contextWindowLengthProperty } from '../descriptions';
import { getSessionId } from '../../../utils/helpers';

class MemoryChatBufferSingleton {
Expand Down Expand Up @@ -130,13 +130,7 @@ export class MemoryBufferWindow implements INodeType {
},
},
sessionKeyProperty,
{
displayName: 'Context Window Length',
name: 'contextWindowLength',
type: 'number',
default: 5,
description: 'The number of previous messages to consider for context',
},
contextWindowLengthProperty,
],
};

Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
/* eslint-disable n8n-nodes-base/node-dirname-against-convention */
import type { IExecuteFunctions, INodeType, INodeTypeDescription, SupplyData } from 'n8n-workflow';
import { NodeConnectionType } from 'n8n-workflow';
import { BufferMemory } from 'langchain/memory';
import { BufferMemory, BufferWindowMemory } from 'langchain/memory';
import { PostgresChatMessageHistory } from '@langchain/community/stores/message/postgres';
import type pg from 'pg';
import { configurePostgres } from 'n8n-nodes-base/dist/nodes/Postgres/v2/transport';
import type { PostgresNodeCredentials } from 'n8n-nodes-base/dist/nodes/Postgres/v2/helpers/interfaces';
import { postgresConnectionTest } from 'n8n-nodes-base/dist/nodes/Postgres/v2/methods/credentialTest';
import { logWrapper } from '../../../utils/logWrapper';
import { getConnectionHintNoticeField } from '../../../utils/sharedFields';
import { sessionIdOption, sessionKeyProperty } from '../descriptions';
import { sessionIdOption, sessionKeyProperty, contextWindowLengthProperty } from '../descriptions';
import { getSessionId } from '../../../utils/helpers';

export class MemoryPostgresChat implements INodeType {
Expand All @@ -18,7 +18,7 @@ export class MemoryPostgresChat implements INodeType {
name: 'memoryPostgresChat',
icon: 'file:postgres.svg',
group: ['transform'],
version: [1],
version: [1, 1.1],
description: 'Stores the chat history in Postgres table.',
defaults: {
name: 'Postgres Chat Memory',
Expand Down Expand Up @@ -60,6 +60,10 @@ export class MemoryPostgresChat implements INodeType {
description:
'The table name to store the chat history in. If table does not exist, it will be created.',
},
{
...contextWindowLengthProperty,
displayOptions: { hide: { '@version': [{ _cnd: { lt: 1.1 } }] } },
},
],
};

Expand All @@ -83,12 +87,19 @@ export class MemoryPostgresChat implements INodeType {
tableName,
});

const memory = new BufferMemory({
const memClass = this.getNode().typeVersion < 1.1 ? BufferMemory : BufferWindowMemory;
const kOptions =
this.getNode().typeVersion < 1.1
? {}
: { k: this.getNodeParameter('contextWindowLength', itemIndex) };

const memory = new memClass({
memoryKey: 'chat_history',
chatHistory: pgChatHistory,
returnMessages: true,
inputKey: 'input',
outputKey: 'output',
...kOptions,
});

async function closeFunction() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@ import {
type SupplyData,
NodeConnectionType,
} from 'n8n-workflow';
import { BufferMemory } from 'langchain/memory';
import { BufferMemory, BufferWindowMemory } from 'langchain/memory';
import type { RedisChatMessageHistoryInput } from '@langchain/redis';
import { RedisChatMessageHistory } from '@langchain/redis';
import type { RedisClientOptions } from 'redis';
import { createClient } from 'redis';
import { logWrapper } from '../../../utils/logWrapper';
import { getConnectionHintNoticeField } from '../../../utils/sharedFields';
import { sessionIdOption, sessionKeyProperty } from '../descriptions';
import { sessionIdOption, sessionKeyProperty, contextWindowLengthProperty } from '../descriptions';
import { getSessionId } from '../../../utils/helpers';

export class MemoryRedisChat implements INodeType {
Expand All @@ -23,7 +23,7 @@ export class MemoryRedisChat implements INodeType {
name: 'memoryRedisChat',
icon: 'file:redis.svg',
group: ['transform'],
version: [1, 1.1, 1.2],
version: [1, 1.1, 1.2, 1.3],
description: 'Stores the chat history in Redis.',
defaults: {
name: 'Redis Chat Memory',
Expand Down Expand Up @@ -95,6 +95,10 @@ export class MemoryRedisChat implements INodeType {
description:
'For how long the session should be stored in seconds. If set to 0 it will not expire.',
},
{
...contextWindowLengthProperty,
displayOptions: { hide: { '@version': [{ _cnd: { lt: 1.3 } }] } },
},
],
};

Expand Down Expand Up @@ -143,12 +147,19 @@ export class MemoryRedisChat implements INodeType {
}
const redisChatHistory = new RedisChatMessageHistory(redisChatConfig);

const memory = new BufferMemory({
const memClass = this.getNode().typeVersion < 1.3 ? BufferMemory : BufferWindowMemory;
const kOptions =
this.getNode().typeVersion < 1.3
? {}
: { k: this.getNodeParameter('contextWindowLength', itemIndex) };

const memory = new memClass({
memoryKey: 'chat_history',
chatHistory: redisChatHistory,
returnMessages: true,
inputKey: 'input',
outputKey: 'output',
...kOptions,
});

async function closeFunction() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
import { NodeConnectionType, NodeOperationError } from 'n8n-workflow';
import type { IExecuteFunctions, INodeType, INodeTypeDescription, SupplyData } from 'n8n-workflow';
import { XataChatMessageHistory } from '@langchain/community/stores/message/xata';
import { BufferMemory } from 'langchain/memory';
import { BufferMemory, BufferWindowMemory } from 'langchain/memory';
import { BaseClient } from '@xata.io/client';
import { logWrapper } from '../../../utils/logWrapper';
import { getConnectionHintNoticeField } from '../../../utils/sharedFields';
import { sessionIdOption, sessionKeyProperty } from '../descriptions';
import { sessionIdOption, sessionKeyProperty, contextWindowLengthProperty } from '../descriptions';
import { getSessionId } from '../../../utils/helpers';

export class MemoryXata implements INodeType {
Expand All @@ -15,7 +15,7 @@ export class MemoryXata implements INodeType {
name: 'memoryXata',
icon: 'file:xata.svg',
group: ['transform'],
version: [1, 1.1, 1.2],
version: [1, 1.1, 1.2, 1.3],
description: 'Use Xata Memory',
defaults: {
name: 'Xata',
Expand Down Expand Up @@ -81,6 +81,10 @@ export class MemoryXata implements INodeType {
},
},
sessionKeyProperty,
{
...contextWindowLengthProperty,
displayOptions: { hide: { '@version': [{ _cnd: { lt: 1.3 } }] } },
},
],
};

Expand Down Expand Up @@ -120,12 +124,19 @@ export class MemoryXata implements INodeType {
apiKey: credentials.apiKey as string,
});

const memory = new BufferMemory({
const memClass = this.getNode().typeVersion < 1.3 ? BufferMemory : BufferWindowMemory;
const kOptions =
this.getNode().typeVersion < 1.3
? {}
: { k: this.getNodeParameter('contextWindowLength', itemIndex) };

const memory = new memClass({
chatHistory,
memoryKey: 'chat_history',
returnMessages: true,
inputKey: 'input',
outputKey: 'output',
...kOptions,
});

return {
Expand Down
8 changes: 8 additions & 0 deletions packages/@n8n/nodes-langchain/nodes/memory/descriptions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,11 @@ export const sessionKeyProperty: INodeProperties = {
},
},
};

export const contextWindowLengthProperty: INodeProperties = {
displayName: 'Context Window Length',
name: 'contextWindowLength',
type: 'number',
default: 5,
hint: 'How many past interactions the model receives as context',
};

0 comments on commit e3edeaa

Please sign in to comment.