79168350

Date: 2024-11-07 23:15:14

As pointed out by @jakevdp, there is an efficiency issue in running the accuracy function because it relies on operations that are not JAX-native. I tried to ensure that only JAX arrays and fully vectorized operations are used, as below:

import jax.numpy as jnp

def accuracy(labels, predictions):
    # Ensure both inputs are JAX arrays
    labels = jnp.array(labels)
    predictions = jnp.array(predictions)
    # A prediction counts as correct when its sign matches the label's sign
    correct_predictions = jnp.sign(labels) == jnp.sign(predictions)
    acc = jnp.mean(correct_predictions)
    return acc

but the performance of evaluating the accuracy is still very poor; in fact, it is infeasible.

As I am dealing with Quantum Neural Networks, specifically Parameterized Quantum Circuits as my model, the number of learnable parameters is very small. Thus, a more feasible and efficient workaround is to log the train and test loss (surprisingly, evaluating the test loss is not very slow); at the very least this helps avoid overfitting. Furthermore, to get the accuracies, save the parameters at each step using jax.debug.callback, and fetch them after the training process to calculate the train and test accuracies for each step (see the loading sketch after the code below). It can be implemented as below.

import os
import pickle
import time
from datetime import datetime

import jax
import jax.numpy as jnp
import optax

# Timestamp for this run, computed once so that every step is saved under the same directory
run_timestamp = datetime.now().strftime("%Y%m%d%H%M%S")

def save_params(params, step):
    # Directory where params for this run are saved
    dir_path = f"stackoverflow_logs/{run_timestamp}/params"
    os.makedirs(dir_path, exist_ok=True)  # Create directories if they don't exist

    # File path for saving params at the current step
    file_path = os.path.join(dir_path, f"{int(step)}.pkl")

    # Save params using pickle
    with open(file_path, "wb") as f:
        pickle.dump(params, f)
    print(f"Parameters saved at step {int(step)} to {file_path}")


@jax.jit
def update_step_jit(i, args):
    # cost, opt and num_batch are defined elsewhere, as in the original setup
    params, opt_state, data, targets, X_test, y_test, X_train, y_train, batch_no, print_training = args
    _data = data[batch_no % num_batch]
    _targets = targets[batch_no % num_batch]
    train_loss, grads = jax.value_and_grad(cost)(params, _data, _targets)
    updates, opt_state = opt.update(grads, opt_state)
    # Only the loss value is needed on the test set, so skip the gradient computation
    test_loss = cost(params, jnp.array(X_test), jnp.array(y_test))
    # Save parameters every step via a host callback
    jax.debug.callback(save_params, params, i)
    params = optax.apply_updates(params, updates)

    def print_fn():
        jax.debug.print("Step: {i}, Train Loss: {train_loss}, Test Loss: {test_loss}", i=i, train_loss=train_loss, test_loss=test_loss)

    # Print every step; replace 1 with N to print every N steps
    jax.lax.cond((jnp.mod(i, 1) == 0) & print_training, print_fn, lambda: None)
    return (params, opt_state, data, targets, X_test, y_test, X_train, y_train, batch_no + 1, print_training)

@jax.jit
def optimization_jit(params, data, targets, X_test, y_test, X_train, y_train, print_training=True):
    opt_state = opt.init(params)
    args = (params, opt_state, data, targets, X_test, y_test, X_train, y_train, 0, print_training)
    # Run 10 optimization steps; adjust the upper bound for longer training
    (params, _, _, _, _, _, _, _, _, _) = jax.lax.fori_loop(0, 10, update_step_jit, args)
    return params

params = optimization_jit(params, X_batched, y_batched, X_test, y_test, X_train, y_train)

# acc(params, X, y) is the model-level accuracy helper from the rest of the setup
var_train_acc = acc(params, X_train, y_train)
var_test_acc = acc(params, X_test, y_test)

print("Training accuracy: ", var_train_acc)
print("Testing accuracy: ", var_test_acc)

Finally, one can also use experiment-tracking frameworks such as wandb, mlflow, or aim inside the callback function.
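
For example, here is a minimal sketch of logging the losses to wandb from inside the jitted step; it assumes wandb.init(...) has already been called, and the metric names are arbitrary:

import wandb

def log_metrics(step, train_loss, test_loss):
    # Runs on the host; jax.debug.callback delivers the traced values as NumPy scalars
    wandb.log(
        {"train_loss": float(train_loss), "test_loss": float(test_loss)},
        step=int(step),
    )

# Then, inside update_step_jit, alongside the jax.debug.print call:
# jax.debug.callback(log_metrics, i, train_loss, test_loss)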

Posted by: Sup