Skip to content

Commit

Permalink
Add secure-fraud-detection demo
Browse files Browse the repository at this point in the history
  • Loading branch information
sberyozkin committed Jun 13, 2024
1 parent 470a775 commit ee23108
Show file tree
Hide file tree
Showing 25 changed files with 903 additions and 0 deletions.
1 change: 1 addition & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@
<module>samples/cli-translator</module>
<module>samples/review-triage</module>
<module>samples/fraud-detection</module>
<module>samples/secure-fraud-detection</module>
<module>samples/chatbot</module>
<module>samples/chatbot-easy-rag</module>
<module>samples/sql-chatbot</module>
Expand Down
118 changes: 118 additions & 0 deletions samples/secure-fraud-detection/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
# Secure Fraud Detection Demo

This demo showcases the implementation of a secure fraud detection system which is available only to users authenticated with Google.
It uses the `gpt-3.5-turbo` LLM, use `quarkus.langchain4j.openai.chat-model.model-name` property to select a different model.

## The Demo

### Setup

The demo requires that your Google account's name and email are configured.
You can use system or env properties, see `Running the Demo` section below.

When the application starts, 5 transactions with random amounts between 1 and 1000 are generated for the registered user.
A random city is also assigned to each transaction.

The setup is defined in the [Setup.java](./src/main/java/io/quarkiverse/langchain4j/samples/Setup.java) class.

The registered user and transactions are stored in a PostgreSQL database. When running the demo in dev mode (recommended), the database is automatically created and populated.

### Content Retrieval

To enable fraud detection, we provide the LLM with access to the custom [FraudDetectionContentRetriever](./src/main/java/io/quarkiverse/langchain4j/samples/FraudDetectionContentRetriever.java) content retriever.

`FraudDetectionContentRetriever` is registered by [FraudDetectionRetrievalAugmentor](./src/main/java/io/quarkiverse/langchain4j/samples/FraudDetectionRetrievalAugmentor.java).

It retrieves transaction data for the currently authenticated user through two Panache repositories:

- [CustomerRepository.java](./src/main/java/io/quarkiverse/langchain4j/samples/CustomerRepository.java)
- [TransactionRepository.java](./src/main/java/io/quarkiverse/langchain4j/samples/TransactionRepository.java)

It extracts the authenticated user's name and email from a custom memory id string representation. Memory id is created by [SecureMemoryIdProvider](./src/main/java/io/quarkiverse/langchain4j/samples/SecureMemoryIdProvider.java) from the authenticated security identity. `SecureMemoryIdProvider` is registered as a Java service provider.

Currently, the memory id has the following format: `userName:userEmail#suffix` with an AI service specific `#suffix` added by the extension runtime in order to correctly segregate memory of concurrent requests to different AI services.

### AI Service

This demo leverages the AI service abstraction, with the interaction between the LLM and the application handled through the AIService interface.

The `io.quarkiverse.langchain4j.sample.FraudDetectionAi` interface uses specific annotations to define the LLM:

```java
@RegisterAiService(retrievalAugmentor = FraudDetectionRetrievalAugmentor.class)
```

For each message, the prompt is engineered to help the LLM understand the context and answer the request:

```java
@SystemMessage("""
You are a bank account fraud detection AI. You have to detect frauds in transactions.
""")
@UserMessage("""
Your task is to detect whether a fraud was committed for the customer.
Answer with a **single** JSON document containing:
- the customer name in the 'customer-name' key
- the transaction limit in the 'transaction-limit' key
- the computed sum of all transactions committed during the last 15 minutes in the 'total' key
- the 'fraud' key set to true if the computed sum of all transactions is greater than the transaction limit
- the 'transactions' key containing an array of JSON objects. Each object must have transaction 'amount', 'city' and formatted 'time' keys.
- the 'explanation' key containing an explanation of your answer.
- the 'email' key containing the customer email if the fraud was detected.
Your response must be just the raw JSON document, without ```json, ``` or anything else. Do not use null JSON properties.
""")
@Timeout(value = 2, unit = ChronoUnit.MINUTES)
String detectAmountFraudForCustomer();
```

_Note:_ You can also use fault tolerance annotations in combination with the prompt annotations.

### Using the AI service

Once defined, you can inject the AI service as a regular bean, and use it:

```java
package io.quarkiverse.langchain4j.sample;

import io.quarkus.security.Authenticated;
import jakarta.ws.rs.GET;
import jakarta.ws.rs.Path;

@Path("/fraud")
@Authenticated
public class FraudDetectionResource {

private final FraudDetectionAi service;

public FraudDetectionResource(FraudDetectionAi service) {
this.service = service;
}

@GET
@Path("/amount")
public String detectBaseOnAmount() {
return service.detectAmountFraudForCustomer();
}
}
```

`FraudDetectionResource` can only be accessed by authenticated users.

## Google Authentication

This demo requires users to authenticate with Google.
All you need to do is to register an application with Google, follow steps listed in the [Quarkus Google](https://quarkus.io/guides/security-openid-connect-providers#google) section.
Name your Google application as `Quarkus LangChain4j AI`, and make sure an allowed callback URL is set to `http://localhost:8080/login`.
Google will generate a client id and secret, use them to set `quarkus.oidc.client-id` and `quarkus.oidc.credentials.secret` properties.

## Running the Demo

To run the demo, use the following command:

```shell
mvn quarkus:dev -Dname="Firstname Familyname" [email protected]
```

Then, access `http://localhost:8080`, login to Google, and follow a provided application link to check the fraud.

134 changes: 134 additions & 0 deletions samples/secure-fraud-detection/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>

<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-sample-secure-fraud-detection</artifactId>
<name>Quarkus LangChain4j - Sample - Secure Fraud Detection</name>
<version>1.0-SNAPSHOT</version>

<properties>
<compiler-plugin.version>3.13.0</compiler-plugin.version>
<maven.compiler.parameters>true</maven.compiler.parameters>
<maven.compiler.release>17</maven.compiler.release>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
<quarkus.platform.artifact-id>quarkus-bom</quarkus.platform.artifact-id>
<quarkus.platform.group-id>io.quarkus</quarkus.platform.group-id>
<quarkus.platform.version>3.9.4</quarkus.platform.version>
<skipITs>true</skipITs>
<surefire-plugin.version>3.2.5</surefire-plugin.version>
<quarkus-langchain4j.version>0.15.1</quarkus-langchain4j.version>
</properties>

<dependencyManagement>
<dependencies>
<dependency>
<groupId>${quarkus.platform.group-id}</groupId>
<artifactId>${quarkus.platform.artifact-id}</artifactId>
<version>${quarkus.platform.version}</version>
<type>pom</type>
<scope>import</scope>
</dependency>
</dependencies>
</dependencyManagement>

<dependencies>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-resteasy-reactive-jackson</artifactId>
</dependency>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-oidc</artifactId>
</dependency>
<dependency>
<groupId>io.quarkiverse.langchain4j</groupId>
<artifactId>quarkus-langchain4j-openai</artifactId>
<version>${quarkus-langchain4j.version}</version>
</dependency>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-smallrye-fault-tolerance</artifactId>
</dependency>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-jdbc-postgresql</artifactId>
</dependency>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-hibernate-orm-panache</artifactId>
</dependency>
<dependency>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-resteasy-reactive-qute</artifactId>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>io.quarkus</groupId>
<artifactId>quarkus-maven-plugin</artifactId>
<version>${quarkus.platform.version}</version>
<executions>
<execution>
<goals>
<goal>build</goal>
</goals>
</execution>
</executions>
</plugin>
<plugin>
<artifactId>maven-compiler-plugin</artifactId>
<version>${compiler-plugin.version}</version>
</plugin>
<plugin>
<artifactId>maven-surefire-plugin</artifactId>
<version>3.2.5</version>
<configuration>
<systemPropertyVariables>
<java.util.logging.manager>org.jboss.logmanager.LogManager</java.util.logging.manager>
<maven.home>${maven.home}</maven.home>
</systemPropertyVariables>
</configuration>
</plugin>
</plugins>
</build>

<profiles>
<profile>
<id>native</id>
<activation>
<property>
<name>native</name>
</property>
</activation>
<build>
<plugins>
<plugin>
<artifactId>maven-failsafe-plugin</artifactId>
<version>3.2.5</version>
<executions>
<execution>
<goals>
<goal>integration-test</goal>
<goal>verify</goal>
</goals>
<configuration>
<systemPropertyVariables>
<native.image.path>${project.build.directory}/${project.build.finalName}-runner</native.image.path>
<java.util.logging.manager>org.jboss.logmanager.LogManager</java.util.logging.manager>
<maven.home>${maven.home}</maven.home>
</systemPropertyVariables>
</configuration>
</execution>
</executions>
</plugin>
</plugins>
</build>
<properties>
<quarkus.package.type>native</quarkus.package.type>
</properties>
</profile>
</profiles>
</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package io.quarkiverse.langchain4j.sample;

import jakarta.persistence.Entity;
import jakarta.persistence.GeneratedValue;
import jakarta.persistence.Id;

@Entity
public class Customer {

@Id
@GeneratedValue
public Long id;
public String name;
public String email;
public int transactionLimit;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package io.quarkiverse.langchain4j.sample;

import io.smallrye.config.ConfigMapping;

@ConfigMapping(prefix = "customer")
public interface CustomerConfig {

String name();

String email();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package io.quarkiverse.langchain4j.sample;

import io.quarkus.hibernate.orm.panache.PanacheRepository;
import jakarta.enterprise.context.ApplicationScoped;

@ApplicationScoped
public class CustomerRepository implements PanacheRepository<Customer> {

/*
* Transaction limit for the customer.
*/
public int getTransactionLimit(String customerName, String customerEmail) {
Customer customer = find("name = ?1 and email = ?2", customerName, customerEmail).firstResult();
if (customer == null) {
throw new MissingCustomerException();
}
return customer.transactionLimit;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package io.quarkiverse.langchain4j.sample;

import java.time.temporal.ChronoUnit;

import org.eclipse.microprofile.faulttolerance.Timeout;

import dev.langchain4j.service.SystemMessage;
import dev.langchain4j.service.UserMessage;
import io.quarkiverse.langchain4j.RegisterAiService;

@RegisterAiService(retrievalAugmentor = FraudDetectionRetrievalAugmentor.class)
public interface FraudDetectionAi {

@SystemMessage("""
You are a bank account fraud detection AI. You have to detect frauds in transactions.
""")
@UserMessage("""
Your task is to detect whether a fraud was committed for the customer.
Answer with a **single** JSON document containing:
- the customer name in the 'customer-name' key
- the transaction limit in the 'transaction-limit' key
- the computed sum of all transactions committed during the last 15 minutes in the 'total' key
- the 'fraud' key set to true if the computed sum of all transactions is greater than the transaction limit
- the 'transactions' key containing an array of JSON objects. Each object must have transaction 'amount', 'city' and formatted 'time' keys.
- the 'explanation' key containing an explanation of your answer.
- the 'email' key containing the customer email if the fraud was detected.
Your response must be just the raw JSON document, without ```json, ``` or anything else. Do not use null JSON properties.
""")
@Timeout(value = 2, unit = ChronoUnit.MINUTES)
String detectAmountFraudForCustomer();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package io.quarkiverse.langchain4j.sample;

import java.time.ZoneOffset;
import java.util.List;

import org.jboss.logging.Logger;

import dev.langchain4j.rag.content.Content;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.rag.query.Query;
import io.quarkiverse.langchain4j.sample.SecureMemoryIdProvider.UserNameAndEmail;
import io.vertx.core.json.JsonArray;
import io.vertx.core.json.JsonObject;
import jakarta.enterprise.context.ApplicationScoped;
import jakarta.enterprise.context.control.ActivateRequestContext;
import jakarta.inject.Inject;

@ApplicationScoped
public class FraudDetectionContentRetriever implements ContentRetriever {
private static final Logger log = Logger.getLogger(FraudDetectionContentRetriever.class);

@Inject
TransactionRepository transactionRepository;

@Inject
CustomerRepository customerRepository;

@ActivateRequestContext
@Override
public List<Content> retrieve(Query query) {
UserNameAndEmail userNameAndEmail = UserNameAndEmail.fromString((String) query.metadata().chatMemoryId());
log.infof("Use customer name %s and email %s to retrieve content", userNameAndEmail.getName(),
userNameAndEmail.getEmail());

int transactionLimit = customerRepository.getTransactionLimit(userNameAndEmail.getName(),
userNameAndEmail.getEmail());

List<Transaction> transactions = transactionRepository.getTransactionsForCustomer(userNameAndEmail.getName(),
userNameAndEmail.getEmail());

JsonArray jsonTransactions = new JsonArray();
for (Transaction t : transactions) {
jsonTransactions.add(JsonObject.of("customer-name", t.customerName, "customer-email", t.customerEmail,
"transaction-amount", t.amount, "transaction-city", t.city,
"transaction-time-in-seconds-from-the-epoch", t.time.toEpochSecond(ZoneOffset.UTC)));
}

JsonObject json = JsonObject.of("transaction-limit", transactionLimit, "transactions", jsonTransactions);
return List.of(Content.from(json.toString()));
}
}
Loading

0 comments on commit ee23108

Please sign in to comment.