Skip to content

Commit

Permalink
Merge pull request #122 from VishwamAI/quantum_models_update
Browse files Browse the repository at this point in the history
Update Quantum Models and Documentation
  • Loading branch information
kasinadhsarma authored Sep 27, 2024
2 parents ee90ac4 + a46a360 commit f44e6e6
Show file tree
Hide file tree
Showing 18 changed files with 1,206 additions and 49 deletions.
74 changes: 55 additions & 19 deletions NeuroFlex/quantum_deep_learning/quantum_boltzmann_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@

class QuantumBoltzmannMachine:
def __init__(self, num_visible, num_hidden, num_qubits):
if num_visible <= 0 or num_hidden <= 0 or num_qubits <= 0:
raise ValueError("num_visible, num_hidden, and num_qubits must be positive integers")
self.num_visible = num_visible
self.num_hidden = num_hidden
self.num_qubits = num_qubits
self.dev = qml.device("default.qubit", wires=num_qubits)
self.params = np.random.uniform(low=-np.pi, high=np.pi, size=(num_visible + num_hidden, 3))
self.weights = self.params # Initialize weights attribute to store the parameters

@qml.qnode(device=qml.device("default.qubit", wires=1))
def qubit_state(self, params):
Expand All @@ -19,29 +22,50 @@ def qubit_state(self, params):
def initialize_state(self):
return [self.qubit_state(self.params[i]) for i in range(self.num_qubits)]

@qml.qnode(device=qml.device("default.qubit", wires=2))
def entangle_qubits(self, params1, params2):
qml.RX(params1[0], wires=0)
qml.RY(params1[1], wires=0)
qml.RZ(params1[2], wires=0)
qml.RX(params2[0], wires=1)
qml.RY(params2[1], wires=1)
qml.RZ(params2[2], wires=1)
qml.CNOT(wires=[0, 1])
return qml.probs(wires=[0, 1])
@qml.qnode(device=self.dev)
def _entangle_qubits(params1, params2):
# Apply rotation gates to qubit 0
qml.RX(params1[0], wires=0)
qml.RY(params1[1], wires=0)
qml.RZ(params1[2], wires=0)
# Apply rotation gates to qubit 1
qml.RX(params2[0], wires=1)
qml.RY(params2[1], wires=1)
qml.RZ(params2[2], wires=1)
# Entangle qubits with CNOT gate
qml.CNOT(wires=[0, 1])
# Return probabilities of the two-qubit state
return qml.probs(wires=[0, 1])
return _entangle_qubits(params1, params2)

def energy(self, visible_state, hidden_state):
energy = 0
energy = 0.0
for i in range(self.num_visible):
for j in range(self.num_hidden):
energy += self.entangle_qubits(self.params[i], self.params[self.num_visible + j])[3] * visible_state[i] * hidden_state[j]
return -energy
params1 = self.params[i]
params2 = self.params[self.num_visible + j]
# Ensure params1 and params2 are correctly defined before passing to entangle_qubits
if params1.shape != (3,) or params2.shape != (3,):
raise ValueError(f"Invalid parameter shapes: params1 {params1.shape}, params2 {params2.shape}")
entangled_state = self.entangle_qubits(params1, params2)
# Ensure entangled_state is a 1D array and has at least 4 elements
if entangled_state.ndim == 1 and entangled_state.shape[0] >= 4:
# Use the absolute value of the last element of entangled_state as the interaction strength
interaction_strength = abs(float(entangled_state[-1]))
energy += interaction_strength * float(visible_state[i]) * float(hidden_state[j])
else:
raise ValueError(f"Unexpected shape of entangled_state: {entangled_state.shape}")
return float(-energy) # Return negative energy as float to align with minimization objective

def sample_hidden(self, visible_state):
hidden_probs = np.zeros(self.num_hidden)
for j in range(self.num_hidden):
hidden_probs[j] = np.mean([self.entangle_qubits(self.params[i], self.params[self.num_visible + j])[3]
for i in range(self.num_visible) if visible_state[i] == 1])
hidden_probs[j] = np.mean([
self.entangle_qubits(params1=self.params[i], params2=self.params[self.num_visible + j])[3]
for i in range(self.num_visible)
if visible_state[i] == 1 and self.params[i].shape == (3,) and self.params[self.num_visible + j].shape == (3,)
])
return (np.random.random(self.num_hidden) < hidden_probs).astype(int)

def sample_visible(self, hidden_state):
Expand All @@ -53,19 +77,31 @@ def sample_visible(self, hidden_state):

def train(self, data, num_epochs, learning_rate):
for epoch in range(num_epochs):
total_energy = 0
for visible_data in data:
hidden_data = self.sample_hidden(visible_data)
visible_model = self.sample_visible(hidden_data)
hidden_model = self.sample_hidden(visible_model)

# Calculate gradients
gradients = np.zeros_like(self.params)
for i in range(self.num_visible):
for j in range(self.num_hidden):
data_term = visible_data[i] * hidden_data[j]
model_term = visible_model[i] * hidden_model[j]
grad = data_term - model_term
gradients[i] += grad
gradients[self.num_visible + j] += grad

# Update parameters
for i in range(self.num_visible + self.num_hidden):
grad = visible_data[i % self.num_visible] * hidden_data[i % self.num_hidden] - \
visible_model[i % self.num_visible] * hidden_model[i % self.num_hidden]
self.params[i] += learning_rate * grad
self.params -= learning_rate * gradients

# Calculate energy for this sample
total_energy += self.energy(visible_data, hidden_data)

avg_energy = total_energy / len(data)
if epoch % 10 == 0:
print(f"Epoch {epoch}, Energy: {self.energy(visible_data, hidden_data)}")
print(f"Epoch {epoch}, Average Energy: {avg_energy}")

def generate_sample(self, num_steps):
visible_state = np.random.randint(2, size=self.num_visible)
Expand Down
143 changes: 131 additions & 12 deletions NeuroFlex/quantum_deep_learning/quantum_cnn.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,154 @@
import numpy as np
import pennylane as qml
import jax
import jax.numpy as jnp

# Single-qubit building block of the quantum CNN: `input_val` is encoded as an
# RX rotation and `params` supplies the trainable (RY, RZ) angles. Runs on a
# fresh one-wire default.qubit device with the JAX interface so jax.grad can
# differentiate through it.
@qml.qnode(device=qml.device("default.qubit", wires=1), interface="jax")
def qubit_layer(params, input_val):
    qml.RX(input_val, wires=0)
    qml.RY(params[0], wires=0)
    qml.RZ(params[1], wires=0)
    return qml.expval(qml.PauliZ(0))  # Return scalar value directly

class QuantumCNN:
def __init__(self, num_qubits, num_layers):
self.num_qubits = num_qubits
self.num_layers = num_layers
self.dev = qml.device("default.qubit", wires=num_qubits)
self.params = np.random.uniform(low=-np.pi, high=np.pi, size=(num_layers, num_qubits, 3))
self.params = None

    def init(self, key, input_shape):
        """Initialize trainable parameters from a JAX PRNG key.

        Stores `key` and `input_shape`, draws one (RY, RZ) angle pair per qubit
        per layer from a standard normal, and returns the parameter array.
        """
        self.key = key
        self.input_shape = input_shape
        # Initialize parameters randomly: shape (num_layers, num_qubits, 2).
        self.params = jax.random.normal(key, (self.num_layers, self.num_qubits, 2))
        return self.params

    def apply(self, params, x):
        """Store `params` on the instance and run the forward pass (flax-like API).

        NOTE(review): `apply` is defined again later in this class in the diff;
        in Python the last definition wins, so verify which one is effective.
        """
        self.params = params
        return self.forward(x)

    def quantum_conv_layer(self, params, x):
        """Quantum convolution over adjacent rows of `x`.

        For inputs with more than one row, output row i is produced by feeding
        the per-qubit average of rows i and i+1 through `qubit_layer`, so the
        output has one fewer row than the input.

        NOTE(review): debug prints left in — consider removing for production.
        NOTE(review): a second `quantum_conv_layer` with a different signature
        appears later in this class in the diff and would shadow this one.
        """
        print(f"quantum_conv_layer input shape: {x.shape}, values: {x}")
        print(f"quantum_conv_layer params shape: {params.shape}, values: {params}")

        if x.shape[0] == 1:
            # Single-row input: each qubit's value is still passed through the
            # circuit; the (1, num_qubits) shape is preserved.
            output = []
            for j in range(self.num_qubits):
                input_val = float(x[0, j])
                qubit_output = qubit_layer(params[j], input_val=input_val)
                output.append(qubit_output)
            output_array = jnp.array([output])
        else:
            output = []
            for i in range(x.shape[0] - 1):  # Reduce output size by 1
                layer_output = []
                for j in range(self.num_qubits):
                    input_val = float((x[i, j] + x[i+1, j]) / 2)  # Ensure input_val is a scalar
                    print(f"input_val for qubit {j}: {input_val}")
                    qubit_output = qubit_layer(params[j], input_val=input_val)  # Average of two adjacent inputs
                    print(f"qubit_output for qubit {j}: {qubit_output}")
                    layer_output.append(qubit_output)
                output.append(layer_output)
            output_array = jnp.array(output)

        print(f"quantum_conv_layer output shape: {output_array.shape}, values: {output_array}")
        return output_array

    def forward(self, x):
        """Apply the conv layers in sequence, one layer's params per step.

        NOTE(review): a second `forward` definition appears later in this class
        in the diff; in Python the later definition wins, so this one may be
        shadowed — verify which is effective.
        """
        for layer in range(self.num_layers):
            x = self.quantum_conv_layer(self.params[layer], x)
        return x

@qml.qnode(device=qml.device("default.qubit", wires=1))
def qubit_layer(self, params):
qml.RX(params[0], wires=0)
qml.RY(params[1], wires=0)
qml.RZ(params[2], wires=0)
return qml.expval(qml.PauliZ(0))
@qml.qnode(qml.device('default.qubit', wires=1))
def qubit_layer(params, input_val):
qml.RX(input_val, wires=0)
qml.RY(params[0], wires=0)
qml.RZ(params[1], wires=0)
return qml.expval(qml.PauliZ(0))

def _qubit_layer(self, params, input_val):
return qubit_layer(params, input_val)

def quantum_conv_layer(self, inputs, params):
outputs = []
for i in range(len(inputs) - 1):
qml.RX(inputs[i], wires=0)
print(f"quantum_conv_layer input shape: {inputs.shape}")
print(f"quantum_conv_layer params shape: {params.shape}")
def apply_layer(i, x):
qml.RX(x, wires=0)
qml.RY(inputs[i+1], wires=1)
qml.CNOT(wires=[0, 1])
outputs.append(self.qubit_layer(params))
return np.array(outputs)
param_index = jax.lax.rem(i, params.shape[0])
slice_params = jax.lax.dynamic_slice(params, (param_index, 0), (1, params.shape[1])).squeeze()
print(f"apply_layer i={i}, x={x}, param_index={param_index}, slice_params={slice_params}")
return self.qubit_layer(params=slice_params, input_val=x)

if inputs.shape[0] > 1:
outputs = jax.vmap(apply_layer, in_axes=(0, 0))(jnp.arange(inputs.shape[0] - 1), inputs[:-1])
print(f"quantum_conv_layer output shape: {outputs.shape}")
reshaped_outputs = outputs.reshape(-1, self.num_qubits)
print(f"quantum_conv_layer reshaped output shape: {reshaped_outputs.shape}")
return reshaped_outputs # Reshape to ensure consistent output shape
else:
print("quantum_conv_layer: input shape <= 1, returning zeros")
return jnp.zeros((1, self.num_qubits)) # Return a 2D array with one row and correct number of columns

    def forward(self, inputs):
        """Run the conv stack, stopping early once the input can no longer shrink,
        then collapse to a single column.

        Returns a 2D array of shape (rows, 1).
        """
        x = inputs
        print(f"forward input shape: {x.shape}")
        for layer in range(self.num_layers):
            x = self.quantum_conv_layer(x, self.params[layer])
            print(f"Layer {layer} output shape: {x.shape}")
            if x.shape[0] <= 1:
                print(f"Cannot reduce further, breaking at layer {layer}")
                break  # Stop if we can't reduce further

        # Ensure the output is 2D
        if x.ndim == 1:
            x = x.reshape(1, -1)

        # If the output has more than one column, take the mean across columns
        if x.shape[1] > 1:
            x = x.mean(axis=1, keepdims=True)

        print(f"Final output shape: {x.shape}")
        return x

def gradient_step(self, inputs, targets, learning_rate):
print(f"gradient_step input shape: {inputs.shape}")
def loss_fn(params):
predictions = self.apply(params, inputs)
return jnp.mean(jnp.square(predictions - targets))

grads = jax.grad(loss_fn)(self.params)
print(f"Gradients: {jax.tree_map(lambda x: jnp.sum(jnp.abs(x)), grads)}")
print(f"Gradient details: {jax.tree_map(lambda x: x, grads)}")

new_params = jax.tree_map(lambda p, g: p - learning_rate * g, self.params, grads)
print(f"Param diff: {jax.tree_map(lambda x, y: jnp.sum(jnp.abs(x - y)), self.params, new_params)}")
print(f"New params: {jax.tree_map(lambda x: x, new_params)}")

self.params = new_params

def calculate_output_shape(self, input_shape):
output_shape = input_shape[0]
for _ in range(self.num_layers):
output_shape -= 1
return (output_shape, 1)

def apply(self, params, inputs):
return self.forward(inputs)

def loss(self, inputs, targets):
predictions = self.forward(inputs)
return jnp.mean((predictions - targets) ** 2)

def apply(self, params, inputs):
self.params = params
return self.forward(inputs)



def loss(self, inputs, targets):
predictions = self.forward(inputs)
return np.mean((predictions - targets) ** 2)
Expand Down
79 changes: 65 additions & 14 deletions NeuroFlex/quantum_deep_learning/quantum_rnn.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,76 @@
import numpy as np
import jax
import jax.numpy as jnp
import pennylane as qml

class QuantumRNN:
def __init__(self, num_qubits, num_layers):
self.num_qubits = num_qubits
self.num_layers = num_layers
self.dev = qml.device("default.qubit", wires=num_qubits)
self.params = np.random.uniform(low=-np.pi, high=np.pi, size=(num_layers, num_qubits, 3))
self.param_init_fn = lambda key, shape: np.random.uniform(low=-np.pi, high=np.pi, size=shape)
self.params = self.param_init_fn(None, (num_layers, num_qubits, 2))

    def init(self, key, input_shape):
        """Record the input shape and (re)initialize the parameters.

        `key` is stored and forwarded to param_init_fn, which currently ignores
        it (numpy RNG). Returns the new parameter array.
        """
        self.key = key
        self.input_shape = input_shape
        # Initialize params for 3D input shape (batch_size, time_steps, features)
        self.params = self.param_init_fn(key, (self.num_layers, self.num_qubits, 2))
        return self.params

@qml.qnode(device=qml.device("default.qubit", wires=1))
def qubit_layer(self, params, input_val):
qml.RX(input_val, wires=0)
qml.RY(params[0], wires=0)
qml.RZ(params[1], wires=0)
return qml.expval(qml.PauliZ(0))
dev = qml.device("default.qubit", wires=self.num_qubits)
@qml.qnode(dev)
def _qubit_circuit(params, input_val):
input_val = jnp.atleast_2d(input_val)
for b in range(input_val.shape[0]): # Iterate over batch
for i in range(self.num_qubits):
qml.RX(input_val[b, i % input_val.shape[1]], wires=i)
qml.RY(params[i][0], wires=i)
qml.RZ(params[i][1], wires=i)
return [qml.expval(qml.PauliZ(i)) for i in range(self.num_qubits)]
return jnp.array([_qubit_circuit(params, jnp.atleast_2d(iv)) for iv in input_val])

def quantum_rnn_layer(self, inputs, params, hidden_state):
outputs = []
for t in range(len(inputs)):
qml.RX(hidden_state, wires=0)
qml.RY(inputs[t], wires=1)
qml.CNOT(wires=[0, 1])
hidden_state = self.qubit_layer(params, inputs[t])
outputs.append(hidden_state)
return np.array(outputs), hidden_state
# Handle both 2D and 3D input shapes
if inputs.ndim == 2:
batch_size, input_features = inputs.shape
time_steps = 1
inputs = jnp.reshape(inputs, (batch_size, time_steps, input_features))
elif inputs.ndim == 3:
batch_size, time_steps, input_features = inputs.shape
else:
raise ValueError(f"Expected 2D or 3D input, got shape {inputs.shape}")

hidden_features = self.num_qubits # Number of features in hidden state

# Initialize hidden_state if it's the first pass
if isinstance(hidden_state, int) and hidden_state == 0:
hidden_state = jnp.zeros((batch_size, hidden_features))

def scan_fn(carry, x):
hidden_state = carry
input_t = x

# Combine hidden state and input, ensuring dimensions match
combined_input = jnp.concatenate([hidden_state, input_t], axis=1)

# Adjust combined_input if necessary to match num_qubits
if combined_input.shape[1] > self.num_qubits:
combined_input = combined_input[:, :self.num_qubits]
elif combined_input.shape[1] < self.num_qubits:
pad_width = self.num_qubits - combined_input.shape[1]
combined_input = jnp.pad(combined_input, ((0, 0), (0, pad_width)), mode='constant')

# Ensure input_val shape matches the expected shape for qubit_layer
input_val = jnp.reshape(combined_input, (-1, self.num_qubits))
new_hidden_state = self.qubit_layer(params=params, input_val=input_val)
return new_hidden_state, new_hidden_state

hidden_state, outputs = jax.lax.scan(scan_fn, hidden_state, jnp.transpose(inputs, (1, 0, 2)))
outputs = jnp.transpose(outputs, (1, 0, 2))

return outputs, hidden_state

def forward(self, inputs):
hidden_state = 0
Expand All @@ -32,6 +79,10 @@ def forward(self, inputs):
x, hidden_state = self.quantum_rnn_layer(x, self.params[layer], hidden_state)
return x

    def apply(self, params, inputs):
        """Set the model parameters and run a forward pass (flax-like API)."""
        self.params = params
        return self.forward(inputs)

def loss(self, inputs, targets):
predictions = self.forward(inputs)
return np.mean((predictions - targets) ** 2)
Expand Down
7 changes: 7 additions & 0 deletions NeuroFlex/quantum_neural_networks/quantum_nn_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@
MAX_HEALING_ATTEMPTS
)

# Check PennyLane version: warn (don't fail) when the installed release differs
# from the one this module was tested against.
# NOTE(review): the merge left this same check duplicated immediately below in
# the file — the second copy should probably be removed.
try:
    pennylane_version = importlib.metadata.version("pennylane")
    if pennylane_version != "0.37.0":
        logging.warning(f"This module was tested with PennyLane 0.37.0. You are using version {pennylane_version}. Some features may not work as expected.")
except importlib.metadata.PackageNotFoundError:
    logging.warning("Unable to determine PennyLane version. Make sure it's installed correctly.")
# Check PennyLane version
try:
pennylane_version = importlib.metadata.version("pennylane")
Expand Down
Loading

0 comments on commit f44e6e6

Please sign in to comment.