Skip to content

Commit

Permalink
Add secure-fraud-detection demo
Browse files Browse the repository at this point in the history
  • Loading branch information
sberyozkin committed May 20, 2024
1 parent 3cd2bd5 commit cf7bfeb
Show file tree
Hide file tree
Showing 18 changed files with 843 additions and 0 deletions.
1 change: 1 addition & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@
<module>samples/cli-translator</module>
<module>samples/review-triage</module>
<module>samples/fraud-detection</module>
<module>samples/secure-fraud-detection</module>
<module>samples/chatbot</module>
<module>samples/chatbot-easy-rag</module>
<module>samples/csv-chatbot</module>
Expand Down
127 changes: 127 additions & 0 deletions samples/secure-fraud-detection/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# Secure Fraud Detection Demo

This demo showcases the implementation of a simple fraud detection system using LLMs (GPT-4 in this case) for users authenticated with Google.

## The Demo

### Setup

The demo is based on fictional random data generated when the application starts, which includes:

- 3 users
- 50 transactions
- For each transaction, a random amount between 1 and 1000 is generated and assigned to a random user. A random city is also assigned to each transaction.

The setup is defined in the [Setup.java](./src/main/java/io/quarkiverse/langchain4j/samples/Setup.java) class.

The users and transactions are stored in a PostgreSQL database. When running the demo in dev mode (recommended), the database is automatically created and populated.

### Tools

To enable fraud detection, we provide the LLM with access to customer and transaction data through two Panache repositories:

- [CustomerRepository.java](./src/main/java/io/quarkiverse/langchain4j/samples/CustomerRepository.java)
- [TransactionRepository.java](./src/main/java/io/quarkiverse/langchain4j/samples/TransactionRepository.java)

### AI Service

This demo leverages the AI service abstraction, with the interaction between the LLM and the application handled through the AIService interface.

The interface uses specific annotations to define the LLM and the tools to be used:

```java
@RegisterAiService(chatMemorySupplier = AiConfig.MemoryProvider.class,
tools = { TransactionRepository.class })
```

For each message, the prompt is engineered to help the LLM understand the context and answer the request:

```java
@SystemMessage("""
You are a bank account fraud detection AI. You have to detect frauds in transactions.
""")
@UserMessage("""
Your task is to detect whether a fraud was committed for the customer {{customerName}}.
To detect a fraud, perform the following actions:
1. Retrieve the transactions for the customer {{customerName}} with the {{customerEmail}} email address for the last 15 minutes.
2. Sum the amount of all these transactions. Ensure the sum is correct.
3. If the amount is greater than 10000, a fraud is detected.
Answer with a **single** JSON document containing:
- the customer name in the 'customer-name' key
- the computed sum in the 'total' key
- the 'fraud' key set to a boolean value indicating if a fraud was detected
- the 'transactions' key containing the list of transaction amounts
- the 'explanation' key containing an explanation of your answer, especially how the sum is computed.
- if there is a fraud, the 'email' key containing an email to the customer {{customerName}} to warn him about the fraud. The text must be formal and polite. It must ask the customer to contact the bank ASAP.
Your response must be just the JSON document, nothing else.
""")
@Timeout(value = 2, unit = ChronoUnit.MINUTES)
String detectAmountFraudForCustomer(long customerId);
```

_Note:_ You can also use fault tolerance annotations in combination with the prompt annotations.

### Using the AI service

Once defined, you can inject the AI service as a regular bean, and use it:

```java
package io.quarkiverse.langchain4j.sample;

import java.util.List;

import org.eclipse.microprofile.jwt.Claims;
import org.eclipse.microprofile.jwt.JsonWebToken;

import io.quarkus.oidc.IdToken;
import io.quarkus.security.Authenticated;
import jakarta.inject.Inject;
import jakarta.ws.rs.GET;
import jakarta.ws.rs.Path;

@Path("/fraud")
@Authenticated
public class FraudDetectionResource {

private final FraudDetectionAi service;
private final TransactionRepository transactions;

@Inject
@IdToken
private JsonWebToken idToken;

public FraudDetectionResource(FraudDetectionAi service, TransactionRepository transactions) {
this.service = service;
this.transactions = transactions;
}

@GET
@Path("/distance")
public String detectBasedOnDistance() {
return service.detectDistanceFraudForCustomer(idToken.getName(), idToken.getClaim(Claims.email));
}

@GET
@Path("/amount")
public String detectBaseOnAmount() {
return service.detectAmountFraudForCustomer(idToken.getName(), idToken.getClaim(Claims.email));
}
}
```

## Running the Demo

To run the demo, use the following commands:

```shell
mvn quarkus:dev
```
Then, issue requests:

```shell
http ":8080/fraud/amount?customerId=1"
http ":8080/fraud/distance?customerId=1"
```
134 changes: 134 additions & 0 deletions samples/secure-fraud-detection/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>

<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-sample-secure-fraud-detection</artifactId>
<name>Quarkus LangChain4j - Sample - Secure Fraud Detection</name>
<version>1.0-SNAPSHOT</version>

<properties>
<compiler-plugin.version>3.13.0</compiler-plugin.version>
<maven.compiler.parameters>true</maven.compiler.parameters>
<maven.compiler.release>17</maven.compiler.release>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
<quarkus.platform.artifact-id>quarkus-bom</quarkus.platform.artifact-id>
<quarkus.platform.group-id>io.quarkus</quarkus.platform.group-id>
<quarkus.platform.version>3.8.2</quarkus.platform.version>
<skipITs>true</skipITs>
<surefire-plugin.version>3.2.5</surefire-plugin.version>
<quarkus-langchain4j.version>0.14.0</quarkus-langchain4j.version>
</properties>

<dependencyManagement>
<dependencies>
<dependency>
<groupId>${quarkus.platform.group-id}</groupId>
<artifactId>${quarkus.platform.artifact-id}</artifactId>
<version>${quarkus.platform.version}</version>
<type>pom</type>
<scope>import</scope>
</dependency>
</dependencies>
</dependencyManagement>

<dependencies>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-resteasy-reactive-jackson</artifactId>
</dependency>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-oidc</artifactId>
</dependency>
<dependency>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-openai</artifactId>
<version>${quarkus-langchain4j.version}</version>
</dependency>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-smallrye-fault-tolerance</artifactId>
</dependency>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-jdbc-postgresql</artifactId>
</dependency>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-hibernate-orm-panache</artifactId>
</dependency>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-resteasy-reactive-qute</artifactId>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-maven-plugin</artifactId>
<version>${quarkus.platform.version}</version>
<executions>
<execution>
<goals>
<goal>build</goal>
</goals>
</execution>
</executions>
</plugin>
<plugin>
<artifactId>maven-compiler-plugin</artifactId>
<version>${compiler-plugin.version}</version>
</plugin>
<plugin>
<artifactId>maven-surefire-plugin</artifactId>
<version>3.2.5</version>
<configuration>
<systemPropertyVariables>
<java.util.logging.manager>org.jboss.logmanager.LogManager</java.util.logging.manager>
<maven.home>${maven.home}</maven.home>
</systemPropertyVariables>
</configuration>
</plugin>
</plugins>
</build>

<profiles>
<profile>
<id>native</id>
<activation>
<property>
<name>native</name>
</property>
</activation>
<build>
<plugins>
<plugin>
<artifactId>maven-failsafe-plugin</artifactId>
<version>3.2.5</version>
<executions>
<execution>
<goals>
<goal>integration-test</goal>
<goal>verify</goal>
</goals>
<configuration>
<systemPropertyVariables>
<native.image.path>${project.build.directory}/${project.build.finalName}-runner</native.image.path>
<java.util.logging.manager>org.jboss.logmanager.LogManager</java.util.logging.manager>
<maven.home>${maven.home}</maven.home>
</systemPropertyVariables>
</configuration>
</execution>
</executions>
</plugin>
</plugins>
</build>
<properties>
<quarkus.package.type>native</quarkus.package.type>
</properties>
</profile>
</profiles>
</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package io.quarkiverse.langchain4j.sample;

import jakarta.persistence.Entity;
import jakarta.persistence.GeneratedValue;
import jakarta.persistence.Id;

@Entity
public class Customer {

@Id
@GeneratedValue
public Long id;
public String name;
public String email;
public int transactionLimit;
public int distanceLimit;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package io.quarkiverse.langchain4j.sample;

import dev.langchain4j.agent.tool.Tool;
import io.quarkus.hibernate.orm.panache.PanacheRepository;
import jakarta.enterprise.context.ApplicationScoped;

@ApplicationScoped
public class CustomerRepository implements PanacheRepository<Customer> {

@Tool("Get the transaction limit for a given customer")
public int getTransactionLimit(String customerName, String customerEmail) {
return find("name = ?1 and email = ?2",
customerName,
customerEmail).firstResult().transactionLimit;
}

@Tool("Get the distance limit for a given customer")
public int getDistanceLimit(String customerName, String customerEmail) {
return find("name = ?1 and email = ?2",
customerName,
customerEmail).firstResult().distanceLimit;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package io.quarkiverse.langchain4j.sample;

import java.time.temporal.ChronoUnit;

import org.eclipse.microprofile.faulttolerance.Timeout;

import dev.langchain4j.service.SystemMessage;
import dev.langchain4j.service.UserMessage;
import io.quarkiverse.langchain4j.RegisterAiService;

@RegisterAiService(tools = { TransactionRepository.class, CustomerRepository.class })
public interface FraudDetectionAi {

@SystemMessage("""
You are a bank account fraud detection AI. You have to detect frauds in transactions.
""")
@UserMessage("""
Your task is to detect whether a fraud was committed for the customer {{customerName}}.
To detect a fraud, perform the following actions:
1 - Retrieve the transaction limit for the customer {{customerName}} with the {{customerEmail}} email address.
2 - Retrieve the transactions for the customer {{customerName}} with the {{customerEmail}} email address for the last 15 minutes.
3 - Sum the amount of all of these transactions. Make sure the sum is correct.
4 - If the amount is greater than the transaction limit for this customer, a fraud is detected.
Answer with a **single** JSON document containing:
- the customer name in the 'customer-name' key
- the 'returning-customer' key set to a boolean value indicating if the same query was already issued before
- the transaction limit in the 'transaction-limit' key
- the computed sum in the 'total' key
- the 'fraud' key set to a boolean value indicating if a fraud was detected
- the 'transactions' key containing the list of transaction amounts
- the 'explanation' key containing a explanation of your answer, including how the sum is computed.
- if there is a fraud, the 'email' key containing an email to the customer {{customerName}} to warn about the fraud.
Your response must be just the raw JSON document, without ```json, ``` or anything else. Do not use null JSON properties.
""")
@Timeout(value = 2, unit = ChronoUnit.MINUTES)
String detectAmountFraudForCustomer(String customerName, String customerEmail);

@SystemMessage("""
You are a bank account fraud detection AI. You have to detect frauds in transactions.
""")
@UserMessage("""
Detect frauds based on the distance between two transactions for the customer: {{customerName}}.
To detect a fraud, perform the following actions:
1 - Retrieve the distance limit in kilometers for the customer {{customerName}} with the {{customerEmail}} email address.
2 - Retrieve the transactions for the customer {{customerName}} with the {{customerEmail}} email address for the last 15 minutes.
3 - Retrieve the city for each transaction.
4 - Check if the distance between 2 cities is greater than the distance limit, if so, a fraud is detected.
5 - If a fraud is detected, find the two transactions associated with these cities.
Answer with a **single** JSON document containing:
- the customer name in the 'customer-name' key
- the distance limit in the 'distance-limit' key
- the amount of the first transaction in the 'first-amount' key
- the amount of the second transaction in the 'second-amount' key
- the city of the first transaction in the 'first-city' key
- the city of the second transaction in the 'second-city' key
- the 'fraud' key set to a boolean value indicating if a fraud was detected (so the distance is greater than the distance limit)
- the 'distance' key set to the distance between the two cities
- the 'explanation' key containing a explanation of your answer.
- the 'cities' key containing all the cities for the transactions for the customer {{customerName}} in the last 15 minutes.
- if there is a fraud, the 'email' key containing an email to the customer {{customerName}} to warn about the fraud.
Your response must be just the raw JSON document, without ```json, ``` or anything else. Do not use null JSON properties.
""")
@Timeout(value = 2, unit = ChronoUnit.MINUTES)
String detectDistanceFraudForCustomer(String customerName, String customerEmail);

}
Loading

0 comments on commit cf7bfeb

Please sign in to comment.