File tree Expand file tree Collapse file tree 2 files changed +14
-1
lines changed Expand file tree Collapse file tree 2 files changed +14
-1
lines changed Original file line number Diff line number Diff line change @@ -198,7 +198,15 @@ class Det(Op):
198198
199199 def make_node (self , x ):
200200 x = as_tensor_variable (x )
201- assert x .ndim == 2
201+ if x .ndim != 2 :
202+ raise ValueError (
203+ f"Input passed is not a valid 2D matrix. Current ndim { x .ndim } != 2"
204+ )
205+ # Check for known shapes and square matrix
206+ if None not in x .type .shape and (x .type .shape [0 ] != x .type .shape [1 ]):
207+ raise ValueError (
208+ f"Determinant not defined for non-square matrix inputs. Shape received is { x .type .shape } "
209+ )
202210 o = scalar (dtype = x .dtype )
203211 return Apply (self , [x ], [o ])
204212
Original file line number Diff line number Diff line change @@ -365,6 +365,11 @@ def test_det():
365365 assert np .allclose (np .linalg .det (r ), f (r ))
366366
367367
368+ def test_det_non_square_raises ():
369+ with pytest .raises (ValueError , match = "Determinant not defined" ):
370+ det (tensor ("x" , shape = (5 , 7 )))
371+
372+
368373def test_det_grad ():
369374 rng = np .random .default_rng (utt .fetch_seed ())
370375
You can’t perform that action at this time.
0 commit comments