Hi @Bnesh,
Thank you for making the MWE. I can replicate the behaviour: with the tf.function decorators commented out the code runs, but when I uncomment them I get this error:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-4-3d54eda5944b> in <cell line: 93>()
91 optimizer = keras.optimizers.Adam(1e-3)
92 steps = 100
---> 93 training(steps)
94 print("Training Done")
10 frames
<ipython-input-4-3d54eda5944b> in training(steps)
78 def training(steps):
79 for i in range(steps):
---> 80 lossl = train_step(full_samples(Nr))
81 print(lossl.numpy())
82
/usr/local/lib/python3.10/dist-packages/tensorflow/python/util/traceback_utils.py in error_handler(*args, **kwargs)
151 except Exception as e:
152 filtered_tb = _process_traceback_frames(e.__traceback__)
--> 153 raise e.with_traceback(filtered_tb) from None
154 finally:
155 del filtered_tb
/tmp/__autograph_generated_fileddnlkzds.py in tf__train_step(xr)
8 do_return = False
9 retval_ = ag__.UndefinedReturnValue()
---> 10 (loss_value, grads) = ag__.converted_call(ag__.ld(grad), (ag__.ld(xr),), None, fscope)
11 clipped_grads = [ag__.converted_call(ag__.ld(tf).clip_by_value, (ag__.ld(g), -1, 1), None, fscope) for g in ag__.ld(grads)]
12 ag__.converted_call(ag__.ld(optimizer).apply_gradients, (ag__.converted_call(ag__.ld(zip), (ag__.ld(clipped_grads), ag__.ld(model).trainable_variables), None, fscope),), None, fscope)
/tmp/__autograph_generated_filerddf9hx9.py in tf__grad(xr)
9 retval_ = ag__.UndefinedReturnValue()
10 with ag__.ld(tf).GradientTape() as tape:
---> 11 loss_value = ag__.converted_call(ag__.ld(loss), (ag__.ld(xr),), None, fscope)
12 try:
13 do_return = True
/tmp/__autograph_generated_fileoyk4ym58.py in tf__loss(xr)
8 do_return = False
9 retval_ = ag__.UndefinedReturnValue()
---> 10 pred = ag__.converted_call(ag__.ld(tf).squeeze, (ag__.converted_call(ag__.ld(model), (ag__.ld(xr),), None, fscope),), None, fscope)
11 true = ag__.converted_call(ag__.ld(true_sol), (ag__.ld(xr),), None, fscope)
12 loss = 1 / ag__.ld(Nr) * ag__.converted_call(ag__.ld(tf).reduce_sum, (ag__.converted_call(ag__.ld(tf).math.square, (ag__.ld(pred) - ag__.ld(true),), None, fscope),), None, fscope)
/usr/local/lib/python3.10/dist-packages/keras/src/utils/traceback_utils.py in error_handler(*args, **kwargs)
68 # To get the full stack trace, call:
69 # `tf.debugging.disable_traceback_filtering()`
---> 70 raise e.with_traceback(filtered_tb) from None
71 finally:
72 del filtered_tb
/usr/local/lib/python3.10/dist-packages/pennylane/qnn/keras.py in tf__call(self, inputs)
35 batch_dims = ag__.Undefined('batch_dims')
36 ag__.if_stmt(ag__.ld(has_batch_dim), if_body, else_body, get_state, set_state, ('batch_dims', 'inputs'), 2)
---> 37 results = ag__.converted_call(ag__.ld(self)._evaluate_qnode, (ag__.ld(inputs),), None, fscope)
38
39 def get_state_1():
/usr/local/lib/python3.10/dist-packages/pennylane/qnn/keras.py in tf___evaluate_qnode(self, x)
59 do_return = False
60 raise
---> 61 ag__.if_stmt(ag__.converted_call(ag__.ld(isinstance), (ag__.ld(res), (ag__.ld(list), ag__.ld(tuple))), None, fscope), if_body_1, else_body_1, get_state_1, set_state_1, ('do_return', 'retval_', 'res'), 2)
62 return fscope.ret(retval_, do_return)
63 return tf___evaluate_qnode
/usr/local/lib/python3.10/dist-packages/pennylane/qnn/keras.py in if_body_1()
43 nonlocal res
44 pass
---> 45 ag__.if_stmt(ag__.converted_call(ag__.ld(len), (ag__.ld(x).shape,), None, fscope) > 1, if_body, else_body, get_state, set_state, ('res',), 1)
46 try:
47 do_return = True
/usr/local/lib/python3.10/dist-packages/pennylane/qnn/keras.py in if_body()
38 def if_body():
39 nonlocal res
---> 40 res = [ag__.converted_call(ag__.ld(tf).reshape, (ag__.ld(r), (ag__.converted_call(ag__.ld(tf).shape, (ag__.ld(x),), None, fscope)[0], ag__.converted_call(ag__.ld(tf).reduce_prod, (ag__.ld(r).shape[1:],), None, fscope))), None, fscope) for r in ag__.ld(res)]
41
42 def else_body():
/usr/local/lib/python3.10/dist-packages/pennylane/qnn/keras.py in <listcomp>(.0)
38 def if_body():
39 nonlocal res
---> 40 res = [ag__.converted_call(ag__.ld(tf).reshape, (ag__.ld(r), (ag__.converted_call(ag__.ld(tf).shape, (ag__.ld(x),), None, fscope)[0], ag__.converted_call(ag__.ld(tf).reduce_prod, (ag__.ld(r).shape[1:],), None, fscope))), None, fscope) for r in ag__.ld(res)]
41
42 def else_body():
ValueError: in user code:
File "<ipython-input-3-78b4de627fee>", line 73, in train_step *
loss_value, grads = grad(xr)
File "<ipython-input-4-3d54eda5944b>", line 68, in grad *
loss_value = loss(xr)
File "<ipython-input-4-3d54eda5944b>", line 58, in loss *
pred = tf.squeeze(model(xr))
File "/usr/local/lib/python3.10/dist-packages/keras/src/utils/traceback_utils.py", line 70, in error_handler **
raise e.with_traceback(filtered_tb) from None
File "/tmp/__autograph_generated_file85erbwpk.py", line 37, in tf__call
results = ag__.converted_call(ag__.ld(self)._evaluate_qnode, (ag__.ld(inputs),), None, fscope)
File "/tmp/__autograph_generated_filexj74w4xm.py", line 61, in tf___evaluate_qnode
ag__.if_stmt(ag__.converted_call(ag__.ld(isinstance), (ag__.ld(res), (ag__.ld(list), ag__.ld(tuple))), None, fscope), if_body_1, else_body_1, get_state_1, set_state_1, ('do_return', 'retval_', 'res'), 2)
File "/tmp/__autograph_generated_filexj74w4xm.py", line 45, in if_body_1
ag__.if_stmt(ag__.converted_call(ag__.ld(len), (ag__.ld(x).shape,), None, fscope) > 1, if_body, else_body, get_state, set_state, ('res',), 1)
File "/tmp/__autograph_generated_filexj74w4xm.py", line 40, in if_body
res = [ag__.converted_call(ag__.ld(tf).reshape, (ag__.ld(r), (ag__.converted_call(ag__.ld(tf).shape, (ag__.ld(x),), None, fscope)[0], ag__.converted_call(ag__.ld(tf).reduce_prod, (ag__.ld(r).shape[1:],), None, fscope))), None, fscope) for r in ag__.ld(res)]
File "/tmp/__autograph_generated_filexj74w4xm.py", line 40, in <listcomp>
res = [ag__.converted_call(ag__.ld(tf).reshape, (ag__.ld(r), (ag__.converted_call(ag__.ld(tf).shape, (ag__.ld(x),), None, fscope)[0], ag__.converted_call(ag__.ld(tf).reduce_prod, (ag__.ld(r).shape[1:],), None, fscope))), None, fscope) for r in ag__.ld(res)]
ValueError: Exception encountered when calling layer 'keras_layer_2' (type KerasLayer).
in user code:
File "/usr/local/lib/python3.10/dist-packages/pennylane/qnn/keras.py", line 414, in call *
results = self._evaluate_qnode(inputs)
File "/usr/local/lib/python3.10/dist-packages/pennylane/qnn/keras.py", line 442, in _evaluate_qnode *
res = [tf.reshape(r, (tf.shape(x)[0], tf.reduce_prod(r.shape[1:]))) for r in res]
ValueError: Cannot convert a partially known TensorShape <unknown> to a Tensor.
Call arguments received by layer 'keras_layer_2' (type KerasLayer):
• inputs=tf.Tensor(shape=(64, 2), dtype=float64)
I’m wondering whether it has anything to do with this bug, but I’m not sure.
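If it is related to that (the traceback suggests the QNode’s output shape becomes unknown while the graph is being traced), one quick check on your side is to leave the decorators in place but force eager execution. This is only a diagnostic, not a fix:

```python
# Diagnostic only: keep the @tf.function decorators but make TensorFlow run
# them eagerly. If the error disappears, it points at graph tracing (the
# unknown output shape inside _evaluate_qnode) rather than the model itself.
import tensorflow as tf

tf.config.run_functions_eagerly(True)   # tf.function-decorated code now runs eagerly
# ... run training(steps) as before ...
tf.config.run_functions_eagerly(False)  # restore graph mode afterwards
```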
I don’t know if it’s possible in your case, but something that can help with speedups is using JAX instead of TensorFlow. I understand you probably have a whole workflow built on TensorFlow and there’s no guarantee it will carry over to JAX, but if the speed issue is really blocking you it may be worth trying.
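Just to show the general shape of what I mean, here is a rough sketch of a PennyLane + JAX training loop. The circuit, toy loss, and learning rate are placeholders and not a drop-in replacement for your model:

```python
# Minimal sketch of a PennyLane QNode trained with JAX (placeholder circuit/loss).
import jax
import jax.numpy as jnp
import pennylane as qml

dev = qml.device("default.qubit", wires=2)

@qml.qnode(dev, interface="jax")
def circuit(weights, x):
    qml.AngleEmbedding(x, wires=[0, 1])
    qml.BasicEntanglerLayers(weights, wires=[0, 1])
    return qml.expval(qml.PauliZ(0))

# jax.jit works for many circuits on default.qubit and is where most of the
# speedup comes from; whether it applies depends on the ops in your circuit.
@jax.jit
def loss(weights, x, y):
    pred = circuit(weights, x)
    return (pred - y) ** 2

grad_fn = jax.jit(jax.grad(loss))

weights = jnp.ones(qml.BasicEntanglerLayers.shape(n_layers=2, n_wires=2))
x = jnp.array([0.1, 0.2])
y = 0.5

# Plain gradient-descent loop standing in for your optimizer.
for _ in range(10):
    weights = weights - 0.1 * grad_fn(weights, x, y)
print(loss(weights, x, y))
```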