@@ -295,16 +295,19 @@ def main(x):
295295def ewise_add (op : BinaryOp , left : SparseTensorBase , right : SparseTensorBase ):
296296 assert left .ndims == right .ndims
297297 assert left .dtype == right .dtype
298+
299+ if left ._obj is None :
300+ return right
301+ if right ._obj is None :
302+ return left
303+
298304 assert left ._sparsity == right ._sparsity
299305
300306 rank = left .ndims
301307 if rank == 0 : # Scalar
302308 # TODO: implement this
303309 raise NotImplementedError ("doesn't yet work for Scalar" )
304310
305- # TODO: handle case of either left or right not having an _obj -> result will be other for ewise_add
306- # or have a utility to build an empty MLIRSparseTensor for all input tensors?
307-
308311 # Build and compile if needed
309312 key = ('ewise_add' , op .name , * left .get_loop_key (), * right .get_loop_key ())
310313 if key not in engine_cache :
@@ -363,15 +366,19 @@ def main(x, y):
363366def ewise_mult (op : BinaryOp , left : SparseTensorBase , right : SparseTensorBase ):
364367 assert left .ndims == right .ndims
365368 assert left .dtype == right .dtype
369+
370+ if left ._obj is None :
371+ return left
372+ if right ._obj is None :
373+ return right
374+
366375 assert left ._sparsity == right ._sparsity
367376
368377 rank = left .ndims
369378 if rank == 0 : # Scalar
370379 # TODO: implement this
371380 raise NotImplementedError ("doesn't yet work for Scalar" )
372381
373- # TODO: handle case of either left or right not having an _obj -> result will be empty for ewise_mult
374-
375382 # Build and compile if needed
376383 key = ('ewise_mult' , op .name , * left .get_loop_key (), * right .get_loop_key ())
377384 if key not in engine_cache :
@@ -433,9 +440,12 @@ def main(x, y):
433440def mxm (op : Semiring , left : Union [Matrix , TransposedMatrix ], right : Union [Matrix , TransposedMatrix ]):
434441 assert left .ndims == right .ndims == 2
435442 assert left .dtype == right .dtype
436- assert left ._sparsity == right ._sparsity
437443
438- # TODO: handle case of either left or right not having an _obj -> result will be empty for mxm
444+ optype = op .binop .get_output_type (left .dtype , right .dtype )
445+ if left ._obj is None or right ._obj is None :
446+ return Matrix .new (optype , left .shape [0 ], right .shape [1 ])
447+
448+ assert left ._sparsity == right ._sparsity
439449
440450 # Build and compile if needed
441451 key = ('mxm' , op .name , * left .get_loop_key (), * right .get_loop_key ())
@@ -446,7 +456,7 @@ def mxm(op: Semiring, left: Union[Matrix, TransposedMatrix], right: Union[Matrix
446456 mem_out = get_sparse_output_pointer ()
447457 arg_pointers = [left ._obj , right ._obj , mem_out ]
448458 engine_cache [key ].invoke ('main' , * arg_pointers )
449- return Matrix (op . binop . get_output_type ( left . dtype , right . dtype ) , [left .shape [0 ], right .shape [1 ]], mem_out ,
459+ return Matrix (optype , [left .shape [0 ], right .shape [1 ]], mem_out ,
450460 left ._sparsity , left .perceived_ordering , intermediate_result = True )
451461
452462
@@ -509,9 +519,10 @@ def main(x, y):
509519def mxv (op : Semiring , left : Union [Matrix , TransposedMatrix ], right : Vector ):
510520 assert left .ndims == 2
511521 assert right .ndims == 1
512- assert left .dtype == right .dtype
513522
514- # TODO: handle case of either left or right not having an _obj -> result will be empty for mxv
523+ optype = op .binop .get_output_type (left .dtype , right .dtype )
524+ if left ._obj is None or right ._obj is None :
525+ return Vector .new (optype , left .shape [0 ])
515526
516527 # Build and compile if needed
517528 key = ('mxv' , op .name , * left .get_loop_key (), * right .get_loop_key ())
@@ -522,7 +533,7 @@ def mxv(op: Semiring, left: Union[Matrix, TransposedMatrix], right: Vector):
522533 mem_out = get_sparse_output_pointer ()
523534 arg_pointers = [left ._obj , right ._obj , mem_out ]
524535 engine_cache [key ].invoke ('main' , * arg_pointers )
525- return Vector (op . binop . get_output_type ( left . dtype , right . dtype ) , [left .shape [0 ]], mem_out ,
536+ return Vector (optype , [left .shape [0 ]], mem_out ,
526537 right ._sparsity , right .perceived_ordering , intermediate_result = True )
527538
528539
@@ -583,9 +594,10 @@ def main(x, y):
583594def vxm (op : Semiring , left : Vector , right : Union [Matrix , TransposedMatrix ]):
584595 assert left .ndims == 1
585596 assert right .ndims == 2
586- assert left .dtype == right .dtype
587597
588- # TODO: handle case of either left or right not having an _obj -> result will be empty for vxm
598+ optype = op .binop .get_output_type (left .dtype , right .dtype )
599+ if left ._obj is None or right ._obj is None :
600+ return Vector .new (optype , right .shape [1 ])
589601
590602 # Build and compile if needed
591603 key = ('vxm' , op .name , * left .get_loop_key (), * right .get_loop_key ())
@@ -596,7 +608,7 @@ def vxm(op: Semiring, left: Vector, right: Union[Matrix, TransposedMatrix]):
596608 mem_out = get_sparse_output_pointer ()
597609 arg_pointers = [left ._obj , right ._obj , mem_out ]
598610 engine_cache [key ].invoke ('main' , * arg_pointers )
599- return Vector (op . binop . get_output_type ( left . dtype , right . dtype ) , [right .shape [1 ]], mem_out ,
611+ return Vector (optype , [right .shape [1 ]], mem_out ,
600612 left ._sparsity , left .perceived_ordering , intermediate_result = True )
601613
602614
@@ -664,26 +676,34 @@ def apply(op: Union[UnaryOp, BinaryOp, IndexUnaryOp],
664676 # TODO: implement this
665677 raise NotImplementedError ("doesn't yet work for Scalar" )
666678
667- # TODO: handle case of empty input (must figure out correct output dtype)
668-
669- # Build and compile if needed
670- # Note that Scalars are included in the key because they are inlined in the compiled code
679+ # Find output dtype
671680 optype = type (op )
672681 if optype is UnaryOp :
673- key = ('apply_unary' , op .name , * sp .get_loop_key (), inplace )
674682 output_dtype = op .get_output_type (sp .dtype )
675683 elif optype is BinaryOp :
676684 if left is not None :
677- key = ('apply_bind_first' , op .name , * sp .get_loop_key (), left ._obj , inplace )
678685 output_dtype = op .get_output_type (left .dtype , sp .dtype )
679686 else :
680- key = ('apply_bind_second' , op .name , * sp .get_loop_key (), right ._obj , inplace )
681687 output_dtype = op .get_output_type (sp .dtype , right .dtype )
682688 else :
683689 if inplace :
684690 raise TypeError ("apply inplace not supported for IndexUnaryOp" )
685- key = ('apply_indexunary' , op .name , * sp .get_loop_key (), thunk ._obj )
686691 output_dtype = op .get_output_type (sp .dtype , thunk .dtype )
692+
693+ if sp ._obj is None :
694+ return sp .baseclass (output_dtype , sp .shape )
695+
696+ # Build and compile if needed
697+ # Note that Scalars are included in the key because they are inlined in the compiled code
698+ if optype is UnaryOp :
699+ key = ('apply_unary' , op .name , * sp .get_loop_key (), inplace )
700+ elif optype is BinaryOp :
701+ if left is not None :
702+ key = ('apply_bind_first' , op .name , * sp .get_loop_key (), left ._obj , inplace )
703+ else :
704+ key = ('apply_bind_second' , op .name , * sp .get_loop_key (), right ._obj , inplace )
705+ else :
706+ key = ('apply_indexunary' , op .name , * sp .get_loop_key (), thunk ._obj )
687707 if key not in engine_cache :
688708 if inplace :
689709 engine_cache [key ] = _build_apply_inplace (op , sp , left , right )
@@ -887,7 +907,8 @@ def main(x):
887907
888908
889909def reduce_to_vector (op : Monoid , mat : Union [Matrix , TransposedMatrix ]):
890- # TODO: handle case of mat not having an _obj -> result will be empty vector
910+ if mat ._obj is None :
911+ return Vector .new (mat .dtype , mat .shape [0 ])
891912
892913 # Build and compile if needed
893914 key = ('reduce_to_vector' , op .name , * mat .get_loop_key ())
@@ -944,7 +965,8 @@ def main(x):
944965
945966
946967def reduce_to_scalar (op : Monoid , sp : SparseTensorBase ):
947- # TODO: handle case of sp not having an _obj -> result will be empty scalar
968+ if sp ._obj is None :
969+ return Scalar .new (sp .dtype )
948970
949971 # Build and compile if needed
950972 key = ('reduce_to_scalar' , op .name , * sp .get_loop_key ())
0 commit comments