1717from .compiler import compile , engine_cache
1818from . descriptor import Descriptor , NULL as NULL_DESC
1919from .utils import (get_sparse_output_pointer , get_scalar_output_pointer ,
20- get_scalar_input_arg , pick_and_renumber_indices )
20+ get_scalar_input_arg , pick_and_renumber_indices , determine_sparsity )
2121from .types import RankedTensorType , BOOL , INT64 , FP64
2222from .exceptions import GrbError , GrbIndexOutOfBounds , GrbDimensionMismatch
2323
@@ -34,7 +34,6 @@ def select_by_mask(sp: SparseTensorBase, mask: SparseTensor, desc: Descriptor =
3434 in `sp` correspond to missing or "falsy" elements in the mask.
3535 """
3636 assert mask .ndims == sp .ndims
37- assert mask ._sparsity == sp ._sparsity
3837 if mask .shape != sp .shape :
3938 raise GrbDimensionMismatch (f"Mask shape mismatch: { mask .shape } != { sp .shape } " )
4039
@@ -62,7 +61,7 @@ def select_by_mask(sp: SparseTensorBase, mask: SparseTensor, desc: Descriptor =
6261 mem_out = get_sparse_output_pointer ()
6362 arg_pointers = [mask ._obj , sp ._obj , mem_out ]
6463 engine_cache [key ].invoke ('main' , * arg_pointers )
65- return mask .baseclass (sp .dtype , mask .shape , mem_out , mask . _sparsity ,
64+ return mask .baseclass (sp .dtype , mask .shape , mem_out , determine_sparsity ( mask , sp ) ,
6665 mask .perceived_ordering , intermediate_result = True )
6766
6867
@@ -80,7 +79,8 @@ def _build_select_by_mask(mask: SparseTensor, sp: SparseTensorBase, complement:
8079 perm_out = ir .AffineMap .get_permutation (range (rank ))
8180 rtt_sp = sp .rtt .as_mlir_type ()
8281 rtt_mask = mask .rtt .as_mlir_type ()
83- rtt_out = mask .rtt .copy (dtype = sp .dtype ).as_mlir_type ()
82+ rtt_out = mask .rtt .copy (dtype = sp .dtype ,
83+ sparsity = determine_sparsity (mask , sp )).as_mlir_type ()
8484
8585 @func .FuncOp .from_py_func (rtt_mask , rtt_sp )
8686 def main (msk , x ):
@@ -368,8 +368,6 @@ def ewise_add(op: BinaryOp, left: SparseTensorBase, right: SparseTensorBase):
368368 engine_cache [key ].invoke ('main' , * arg_pointers )
369369 return Scalar (left .dtype , (), left .dtype .np_type (mem_out .contents .value ))
370370
371- assert left ._sparsity == right ._sparsity
372-
373371 # Build and compile if needed
374372 key = ('ewise_add' , op .name , * left .get_loop_key (), * right .get_loop_key ())
375373 if key not in engine_cache :
@@ -380,7 +378,8 @@ def ewise_add(op: BinaryOp, left: SparseTensorBase, right: SparseTensorBase):
380378 arg_pointers = [left ._obj , right ._obj , mem_out ]
381379 engine_cache [key ].invoke ('main' , * arg_pointers )
382380 return left .baseclass (op .get_output_type (left .dtype , right .dtype ), left .shape , mem_out ,
383- left ._sparsity , left .perceived_ordering , intermediate_result = True )
381+ determine_sparsity (left , right , union = True ), left .perceived_ordering ,
382+ intermediate_result = True )
384383
385384
386385def _build_ewise_add (op : BinaryOp , left : SparseTensorBase , right : SparseTensorBase ):
@@ -395,7 +394,8 @@ def _build_ewise_add(op: BinaryOp, left: SparseTensorBase, right: SparseTensorBa
395394 perm_out = ir .AffineMap .get_permutation (range (rank ))
396395 rtt_left = left .rtt .as_mlir_type ()
397396 rtt_right = right .rtt .as_mlir_type ()
398- rtt_out = left .rtt .copy (ordering = left .perceived_ordering ).as_mlir_type ()
397+ rtt_out = left .rtt .copy (ordering = left .perceived_ordering ,
398+ sparsity = determine_sparsity (left , right , union = True )).as_mlir_type ()
399399
400400 @func .FuncOp .from_py_func (rtt_left , rtt_right )
401401 def main (x , y ):
@@ -443,8 +443,6 @@ def ewise_mult(op: BinaryOp, left: SparseTensorBase, right: SparseTensorBase):
443443 engine_cache [key ].invoke ('main' , * arg_pointers )
444444 return Scalar (output_dtype , (), output_dtype .np_type (mem_out .contents .value ))
445445
446- assert left ._sparsity == right ._sparsity
447-
448446 # Build and compile if needed
449447 key = ('ewise_mult' , op .name , * left .get_loop_key (), * right .get_loop_key ())
450448 if key not in engine_cache :
@@ -455,7 +453,8 @@ def ewise_mult(op: BinaryOp, left: SparseTensorBase, right: SparseTensorBase):
455453 arg_pointers = [left ._obj , right ._obj , mem_out ]
456454 engine_cache [key ].invoke ('main' , * arg_pointers )
457455 return left .baseclass (output_dtype , left .shape , mem_out ,
458- left ._sparsity , left .perceived_ordering , intermediate_result = True )
456+ determine_sparsity (left , right ), left .perceived_ordering ,
457+ intermediate_result = True )
459458
460459
461460def _build_ewise_mult (op : BinaryOp , left : SparseTensorBase , right : SparseTensorBase ):
@@ -472,7 +471,9 @@ def _build_ewise_mult(op: BinaryOp, left: SparseTensorBase, right: SparseTensorB
472471 perm_out = ir .AffineMap .get_permutation (range (rank ))
473472 rtt_left = left .rtt .as_mlir_type ()
474473 rtt_right = right .rtt .as_mlir_type ()
475- rtt_out = left .rtt .copy (dtype = op_result_dtype , ordering = left .perceived_ordering ).as_mlir_type ()
474+ rtt_out = RankedTensorType (dtype = op_result_dtype ,
475+ sparsity = determine_sparsity (left , right ),
476+ ordering = left .perceived_ordering ).as_mlir_type ()
476477
477478 @func .FuncOp .from_py_func (rtt_left , rtt_right )
478479 def main (x , y ):
@@ -511,8 +512,6 @@ def mxm(op: Semiring, left: Union[Matrix, TransposedMatrix], right: Union[Matrix
511512 if left ._obj is None or right ._obj is None :
512513 return Matrix .new (optype , left .shape [0 ], right .shape [1 ])
513514
514- assert left ._sparsity == right ._sparsity
515-
516515 # Build and compile if needed
517516 key = ('mxm' , op .name , * left .get_loop_key (), * right .get_loop_key ())
518517 if key not in engine_cache :
@@ -523,7 +522,7 @@ def mxm(op: Semiring, left: Union[Matrix, TransposedMatrix], right: Union[Matrix
523522 arg_pointers = [left ._obj , right ._obj , mem_out ]
524523 engine_cache [key ].invoke ('main' , * arg_pointers )
525524 return Matrix (optype , [left .shape [0 ], right .shape [1 ]], mem_out ,
526- left . _sparsity , left .perceived_ordering , intermediate_result = True )
525+ determine_sparsity ( left , right ) , left .perceived_ordering , intermediate_result = True )
527526
528527
529528def _build_mxm (op : Semiring , left : Union [Matrix , TransposedMatrix ], right : Union [Matrix , TransposedMatrix ]):
@@ -539,7 +538,9 @@ def _build_mxm(op: Semiring, left: Union[Matrix, TransposedMatrix], right: Union
539538 perm_out = ir .AffineMap .get (3 , 0 , [ir .AffineDimExpr .get (0 ), ir .AffineDimExpr .get (1 )])
540539 rtt_left = left .rtt .as_mlir_type ()
541540 rtt_right = right .rtt .as_mlir_type ()
542- rtt_out = left .rtt .copy (dtype = op_result_dtype , ordering = left .perceived_ordering ).as_mlir_type ()
541+ rtt_out = RankedTensorType (dtype = op_result_dtype ,
542+ sparsity = determine_sparsity (left , right ),
543+ ordering = left .perceived_ordering ).as_mlir_type ()
543544
544545 @func .FuncOp .from_py_func (rtt_left , rtt_right )
545546 def main (x , y ):
@@ -1223,18 +1224,18 @@ def assign(tensor: SparseTensorBase, row_indices, col_indices, row_size, col_siz
12231224 v = Vector .new (tensor .dtype , row_size )
12241225 # Map idx to output indices
12251226 idx = np .array (row_indices , dtype = np .uint64 )[idx ]
1226- v .build (idx , vals )
1227+ v .build (idx , vals , sparsity = tensor . _sparsity )
12271228 return v
12281229 # Assign Vector as row or column of Matrix
12291230 m = Matrix .new (tensor .dtype , row_size , col_size )
12301231 if type (row_indices ) is int :
12311232 # Map idx to output cols
12321233 colidx = idx if col_indices is None else np .array (col_indices , dtype = np .uint64 )[idx ]
1233- m .build ([row_indices ]* len (vals ), colidx , vals )
1234+ m .build ([row_indices ]* len (vals ), colidx , vals , sparsity = [ "compressed" , "compressed" ] )
12341235 if type (col_indices ) is int :
12351236 # Map idx to output rows
12361237 rowidx = idx if row_indices is None else np .array (row_indices , dtype = np .uint64 )[idx ]
1237- m .build (rowidx , [col_indices ]* len (vals ), vals )
1238+ m .build (rowidx , [col_indices ]* len (vals ), vals , sparsity = [ "compressed" , "compressed" ] )
12381239 return m
12391240
12401241 # Matrix input
@@ -1249,5 +1250,5 @@ def assign(tensor: SparseTensorBase, row_indices, col_indices, row_size, col_siz
12491250 if col_indices is not None :
12501251 colidx = np .array (col_indices , dtype = np .uint64 )[colidx ]
12511252 m = Matrix .new (tensor .dtype , row_size , col_size )
1252- m .build (rowidx , colidx , vals )
1253+ m .build (rowidx , colidx , vals , sparsity = [ "compressed" , "compressed" ] )
12531254 return m
0 commit comments