@@ -721,39 +721,59 @@ def struct_column(op, **kw):
721
721
ops .All : "all" ,
722
722
ops .Any : "any" ,
723
723
ops .ApproxMedian : "median" ,
724
- ops .Arbitrary : "first" ,
725
724
ops .Count : "count" ,
726
725
ops .CountDistinct : "n_unique" ,
727
- ops .First : "first" ,
728
- ops .Last : "last" ,
729
726
ops .Max : "max" ,
730
727
ops .Mean : "mean" ,
731
728
ops .Median : "median" ,
732
729
ops .Min : "min" ,
733
- ops .StandardDev : "std" ,
734
730
ops .Sum : "sum" ,
735
- ops .Variance : "var" ,
736
731
}
737
732
738
- for reduction in _reductions .keys ():
739
733
740
- @translate .register (reduction )
741
- def reduction (op , ** kw ):
742
- args = [
743
- translate (arg , ** kw )
744
- for name , arg in zip (op .argnames , op .args )
745
- if name not in ("where" , "how" )
746
- ]
734
+ def execute_reduction (op , ** kw ):
735
+ arg = translate (op .arg , ** kw )
736
+
737
+ if op .where is not None :
738
+ arg = arg .filter (translate (op .where , ** kw ))
739
+
740
+ method = _reductions [type (op )]
741
+
742
+ return getattr (arg , method )()
743
+
744
+
745
+ for cls in _reductions :
746
+ translate .register (cls , execute_reduction )
747
+
748
+
749
+ @translate .register (ops .First )
750
+ @translate .register (ops .Last )
751
+ @translate .register (ops .Arbitrary )
752
+ def execute_first_last (op , ** kw ):
753
+ arg = translate (op .arg , ** kw )
754
+
755
+ # polars doesn't ignore nulls by default for these methods
756
+ predicate = arg .is_not_null ()
757
+ if op .where is not None :
758
+ predicate &= translate (op .where , ** kw )
759
+
760
+ arg = arg .filter (predicate )
761
+
762
+ return arg .last () if isinstance (op , ops .Last ) else arg .first ()
747
763
748
- agg = _reductions [type (op )]
749
764
750
- predicates = [arg .is_not_null () for arg in args ]
751
- if (where := op .where ) is not None :
752
- predicates .append (translate (where , ** kw ))
765
+ @translate .register (ops .StandardDev )
766
+ @translate .register (ops .Variance )
767
+ def execute_std_var (op , ** kw ):
768
+ arg = translate (op .arg , ** kw )
769
+
770
+ if op .where is not None :
771
+ arg = arg .filter (translate (op .where , ** kw ))
772
+
773
+ method = "std" if isinstance (op , ops .StandardDev ) else "var"
774
+ ddof = 0 if op .how == "pop" else 1
753
775
754
- first , * rest = args
755
- method = operator .methodcaller (agg , * rest )
756
- return method (first .filter (reduce (operator .and_ , predicates )))
776
+ return getattr (arg , method )(ddof = ddof )
757
777
758
778
759
779
@translate .register (ops .Mode )
0 commit comments